Example #1
0
 def visit_AugAssign(self, node):
     """Vectorize an augmented assignment ``target op= value``.

     ``node.value`` is always visited first.  Then:

     * If the target mentions ``self.loop_var`` but its innermost subscript
       does not, the target buffer is rewritten so that ``loop_var`` becomes
       the innermost subscript.  The base buffer is renamed
       ``<name>_transposed``, and the depth the subscript was moved from is
       recorded in ``self.transposed_buffers``.  The value to combine with
       is obtained by visiting a deep copy of the target taken *before* this
       rewrite.  A ``store_ps`` of ``current op value`` to the rewritten
       target's address is returned.
     * If the innermost subscript mentions ``loop_var``, a ``store_ps`` of
       ``visit(target) op value`` to the target's address is returned.
     * ``target += <FunctionCall>`` becomes
       ``target = _mm256_add_ps(value, target)``.
     * Any other array-element target raises ``NotImplementedError``.
     * Otherwise the target is visited and ``node`` is returned.

     Raises:
         NotImplementedError: if one buffer would be transposed at two
             different subscript depths, or for unsupported array targets.
     """
     node.value = self.visit(node.value)

     def store_address(ref):
         # &ref[0] is just ref's base, so pass that directly.  The constant
         # lives in ``ref.right`` (a BinaryOp has no ``.value`` attribute).
         if isinstance(ref.right, C.Constant) and ref.right.value == 0.0:
             return ref.left
         return C.Ref(ref)

     if util.contains_symbol(node.target, self.loop_var):
         if not util.contains_symbol(node.target.right, self.loop_var):
             # Visit a copy taken before the target is rewritten below.
             current = self.visit(deepcopy(node.target))
             # Walk outward until the subscript one level up is the bare
             # loop variable; the isinstance guard lets constant or compound
             # subscripts be skipped instead of failing on ``.name``.
             curr_node = node.target
             idx = 1
             while not (isinstance(curr_node.left.right, C.SymbolRef) and
                        curr_node.left.right.name == self.loop_var):
                 curr_node = curr_node.left
                 idx += 1
             # Remove the loop_var subscript and re-append it innermost.
             curr_node.left = curr_node.left.left
             node.target = C.ArrayRef(node.target,
                                      C.SymbolRef(self.loop_var))
             while not isinstance(curr_node, C.SymbolRef):
                 curr_node = curr_node.left
             if curr_node.name in self.transposed_buffers and \
                     self.transposed_buffers[curr_node.name] != idx:
                 # One buffer cannot be transposed two different ways.
                 raise NotImplementedError()
             self.transposed_buffers[curr_node.name] = idx
             curr_node.name += "_transposed"
             return store_ps(store_address(node.target),
                             C.BinaryOp(current, node.op, node.value))
         address = store_address(node.target)
         return store_ps(address,
                         C.BinaryOp(self.visit(node.target), node.op,
                                    node.value))
     if isinstance(node.op, C.Op.Add) and isinstance(
             node.value, C.FunctionCall):
         # TODO: Verify it's a vector intrinsic
         return C.Assign(
             node.target,
             C.FunctionCall(C.SymbolRef("_mm256_add_ps"),
                            [node.value, node.target]))
     if isinstance(node.target, C.BinaryOp) and isinstance(
             node.target.op, C.Op.ArrayRef):
         raise NotImplementedError(node)
     node.target = self.visit(node.target)
     return node
    def visit_BinaryOp(self, node):
        """Rewrite array reads for the vectorized loop.

        Non-subscript binary ops have both operands visited recursively and
        are returned.  For an array subscript:

        * not mentioning ``self.loop_var``: wrapped in ``mm256_set1_ps``
          when ``self.vectorize`` is set, otherwise returned unchanged;
        * mentioning ``self.loop_var``: the number of subscript levels
          between the outermost reference and the level whose subscript is
          the bare ``loop_var`` symbol is stored in
          ``self.vectorized_buffers`` under the base buffer's name.  The
          access then becomes ``mm256_load_ps`` (vectorizing) or gets an
          extra ``[_neuron_index_1_inner]`` subscript (not vectorizing).
        """
        if not isinstance(node.op, C.Op.ArrayRef):
            node.left = self.visit(node.left)
            node.right = self.visit(node.right)
            return node

        if not util.contains_symbol(node, self.loop_var):
            return simd_macros.mm256_set1_ps(node) if self.vectorize else node

        # NOTE(review): assumes loop_var appears as a bare SymbolRef
        # subscript; otherwise this walk runs past the base symbol.
        depth = 0
        ref = node
        while not (isinstance(ref.right, C.SymbolRef)
                   and ref.right.name == self.loop_var):
            depth += 1
            ref = ref.left
        base = ref
        while not isinstance(base, C.SymbolRef):
            base = base.left
        self.vectorized_buffers[base.name] = depth

        if self.vectorize:
            return simd_macros.mm256_load_ps(node)
        return C.ArrayRef(node, C.SymbolRef("_neuron_index_1_inner"))
 def visit_AugAssign(self, node):
     """Lower ``target op= value`` to AVX intrinsics when vectorizing.

     ``node.value`` is visited first.  Without ``self.vectorize`` only the
     target is visited and ``node`` is returned.  Otherwise:

     * target mentions ``self.loop_var``:
       ``mm256_store_ps(target, visit(target) op value)``;
     * ``target += a * b``: ``target = _mm256_fmadd_ps(a, b, target)``;
     * ``target += <FunctionCall>``:
       ``target = _mm256_add_ps(value, target)``;
     * any other array-element target raises ``NotImplementedError``;
     * anything else: the target is visited and ``node`` returned.
     """
     node.value = self.visit(node.value)
     if not self.vectorize:
         node.target = self.visit(node.target)
         return node

     if util.contains_symbol(node.target, self.loop_var):
         # The target varies with the vectorized loop: load, combine, store.
         combined = C.BinaryOp(self.visit(node.target), node.op, node.value)
         return simd_macros.mm256_store_ps(node.target, combined)

     is_add = isinstance(node.op, C.Op.Add)
     value = node.value
     if is_add and isinstance(value, C.BinaryOp) and \
             isinstance(value.op, C.Op.Mul):
         # acc += a * b  ->  acc = fmadd(a, b, acc)
         fmadd = C.FunctionCall(C.SymbolRef("_mm256_fmadd_ps"),
                                [value.left, value.right, node.target])
         return C.Assign(node.target, fmadd)
     if is_add and isinstance(value, C.FunctionCall):
         # TODO: Verify it's a vector intrinsic
         vector_add = C.FunctionCall(C.SymbolRef("_mm256_add_ps"),
                                     [value, node.target])
         return C.Assign(node.target, vector_add)
     if isinstance(node.target, C.BinaryOp) and \
             isinstance(node.target.op, C.Op.ArrayRef):
         raise NotImplementedError()
     node.target = self.visit(node.target)
     return node
Example #4
0
    def visit_If(self, node):
        """Replicate an ``if`` statement once per unrolled iteration.

        If the statement mentions ``self.target_var`` or any name in
        ``self.unrolled_vars``, ``self.factor`` deep copies are produced.
        In copy ``i``:

        * every unrolled variable ``v`` is renamed ``v_i``;
        * ``self.target_var`` is replaced by ``target_var + i``
          (``unroll_type == 0``) or ``factor * target_var + i``
          (``unroll_type == 1``).

        Returns:
            The list of copies, or ``node`` unchanged if it mentions none
            of the watched names.

        Raises:
            NotImplementedError: if ``self.unroll_type`` is neither 0 nor 1.
        """
        check = [
            util.contains_symbol(node, var)
            for var in list(self.unrolled_vars) + [self.target_var]
        ]
        if not any(check):
            return node

        body = []
        for i in range(self.factor):
            stmt = deepcopy(node)
            for var in self.unrolled_vars:
                stmt = util.replace_symbol(var,
                                           C.SymbolRef(var + "_" + str(i)),
                                           stmt)
            if self.unroll_type == 0:
                index = C.Add(C.SymbolRef(self.target_var), C.Constant(i))
            elif self.unroll_type == 1:
                index = C.Add(
                    C.Mul(C.Constant(self.factor),
                          C.SymbolRef(self.target_var)),
                    C.Constant(i))
            else:
                # Explicit raise: ``false`` is undefined in Python and an
                # assert would be stripped entirely under ``python -O``.
                raise NotImplementedError(
                    "unsupported unroll_type: {}".format(self.unroll_type))
            body.append(util.replace_symbol(self.target_var, index, stmt))
        return body
Example #5
0
 def visit_BinaryOp(self, node):
     """Unroll an assignment whose right-hand side depends on the unrolled loop.

     Only ``C.Op.Assign`` nodes are considered; every other binary op is
     returned unchanged.  If the right-hand side mentions
     ``self.target_var`` or a name in ``self.unrolled_vars``:

     * a left-hand side with a non-None ``type`` (presumably a declaration)
       is itself added to ``self.unrolled_vars``;
     * ``self.factor`` deep copies are returned.  In copy ``i`` each
       unrolled variable ``v`` becomes ``v_i``, and ``target_var`` becomes
       ``target_var + i`` (``unroll_type == 0``) or
       ``factor * target_var + i`` (``unroll_type == 1``).

     Raises:
         NotImplementedError: if ``self.unroll_type`` is neither 0 nor 1.
     """
     if not isinstance(node.op, C.Op.Assign):
         return node
     check = [
         util.contains_symbol(node.right, var)
         for var in list(self.unrolled_vars) + [self.target_var]
     ]
     if not any(check):
         return node

     if hasattr(node.left, 'type') and node.left.type is not None:
         # A temporary declared here needs a per-copy name as well.
         self.unrolled_vars.add(node.left.name)
     body = []
     for i in range(self.factor):
         stmt = deepcopy(node)
         for var in self.unrolled_vars:
             stmt = util.replace_symbol(
                 var, C.SymbolRef(var + "_" + str(i)), stmt)
         if self.unroll_type == 0:
             index = C.Add(C.SymbolRef(self.target_var), C.Constant(i))
         elif self.unroll_type == 1:
             index = C.Add(
                 C.Mul(C.Constant(self.factor),
                       C.SymbolRef(self.target_var)),
                 C.Constant(i))
         else:
             # Explicit raise: ``false`` is undefined in Python and an
             # assert would be stripped entirely under ``python -O``.
             raise NotImplementedError(
                 "unsupported unroll_type: {}".format(self.unroll_type))
         body.append(util.replace_symbol(self.target_var, index, stmt))
     return body
Example #6
0
 def visit_BinaryOp(self, node):
     """Reorder nested array subscripts to follow ``self.loop_vars`` order.

     Both operands are visited first.  For ``ref[a][b]``, each subscript is
     ranked by the position in ``self.loop_vars`` of the first loop
     variable it mentions.  A subscript mentioning none of them ranks
     last, as does one matching only the final loop variable.  If the
     inner subscript ``b`` ranks earlier than ``a``, they are swapped, and
     the inner reference is revisited so the moved subscript can keep
     moving outward.

     Only genuine nested subscripts (``node.left`` is itself an
     ``ArrayRef``) are reordered.  Other left operands such as pointer
     arithmetic are left alone.  An empty ``self.loop_vars`` leaves the
     node unchanged.
     """
     node.left = self.visit(node.left)
     node.right = self.visit(node.right)
     if not (isinstance(node.op, C.Op.ArrayRef) and
             isinstance(node.left, C.BinaryOp) and
             isinstance(node.left.op, C.Op.ArrayRef)):
         return node

     def rank(expr):
         # Mirrors a for/break search: with no match the last position is
         # kept; with no loop vars at all the rank is -1 (no swap).
         pos = -1
         for pos, var in enumerate(self.loop_vars):
             if util.contains_symbol(expr, var):
                 break
         return pos

     if rank(node.right) < rank(node.left.right):
         node.left.right, node.right = node.right, node.left.right
         node.left = self.visit(node.left)
     return node
Example #7
0
 def visit(self, node):
     """Visit ``node``, then sink broadcast assignments to their first use.

     After the generic visit, if ``node`` has a ``body``, each statement of
     the form ``x = _mm256_broadcast_ss(...)`` is moved directly before the
     first later statement of that body that mentions ``x``.  All other
     statements keep their relative order.

     NOTE(review): a broadcast whose target is mentioned by no later
     statement in the same body is dropped entirely -- confirm intended.
     """
     node = super().visit(node)
     if not hasattr(node, 'body'):
         return node

     sunk = []
     # Walk backwards so ``sunk`` always holds the statements that follow
     # the one currently being placed.
     for stmt in reversed(node.body):
         is_broadcast = (isinstance(stmt, C.BinaryOp)
                         and isinstance(stmt.op, C.Op.Assign)
                         and isinstance(stmt.right, C.FunctionCall)
                         and stmt.right.func.name in ["_mm256_broadcast_ss"])
         if not is_broadcast:
             sunk.insert(0, stmt)
             continue
         name = stmt.left.name
         first_use = next(
             (pos for pos, later in enumerate(sunk)
              if util.contains_symbol(later, name)),
             None)
         if first_use is not None:
             sunk.insert(first_use, stmt)
     node.body = sunk
     return node
Example #8
0
 def visit_For(self, node):
     """Visit the loop body, then regroup its statements into three bands.

     Every statement is visited.  Each simple statement (one without its own
     ``body``) is then classified, in this order:

     * loads: it mentions ``_mm256_load_ps``;
     * declarations: an assignment to a ``SymbolRef`` whose ``type`` is
       set;
     * others: everything else.

     Compound statements always count as "others".  The body becomes
     declarations + loads + others, each band keeping its original
     relative order.

     NOTE(review): hoisting declarations ahead of earlier statements
     assumes their initializers don't depend on those statements -- confirm.
     """
     node.body = [self.visit(s) for s in node.body]
     declarations, loads, others = [], [], []
     for stmt in node.body:
         if hasattr(stmt, 'body'):
             others.append(stmt)
         elif util.contains_symbol(stmt, "_mm256_load_ps"):
             loads.append(stmt)
         elif (isinstance(stmt, C.BinaryOp)
               and isinstance(stmt.op, C.Op.Assign)
               and isinstance(stmt.left, C.SymbolRef)
               and stmt.left.type is not None):
             declarations.append(stmt)
         else:
             others.append(stmt)
     node.body = declarations + loads + others
     return node
Example #9
0
 def visit_For(self, node):
     """Hoist loop-invariant SIMD loads/declarations and sink invariant stores.

     The body is visited and flattened first.  A loop whose init variable
     is ``_neuron_index_0`` is returned untouched.  For any other loop the
     body is scanned in order.  A statement counts as invariant when it
     mentions neither the loop variable nor any typed declaration kept
     earlier in the body.

     * An invariant call whose name contains ``_mm`` and ``_store`` moves
       after the loop.
     * An invariant assignment from a ``*_load*`` call moves before the
       loop.
     * An invariant assignment to a typed ``SymbolRef`` also moves before
       the loop.
     * Everything else stays.  Typed declarations that stay are remembered,
       so later statements using them stay as well.

     Returns:
         ``node`` for the outermost loop, otherwise the list
         ``pre_stmts + [node] + post_stmts``.
     """
     node.body = util.flatten([self.visit(s) for s in node.body])
     if node.init.left.name == "_neuron_index_0":
         # Don't lift out of the outermost loop.
         return node

     loop_var = node.init.left.name
     hoisted, kept, sunk = [], [], []
     kept_decls = set()

     def invariant(stmt):
         return not util.contains_symbol(stmt, loop_var) and \
             not any(util.contains_symbol(stmt, dep) for dep in kept_decls)

     def declares(stmt):
         return isinstance(stmt.left, C.SymbolRef) and \
             stmt.left.type is not None

     for stmt in node.body:
         is_assign = isinstance(stmt, C.BinaryOp) and \
             isinstance(stmt.op, C.Op.Assign)
         if isinstance(stmt, C.FunctionCall) and "_mm" in stmt.func.name and \
                 "_store" in stmt.func.name and invariant(stmt):
             sunk.append(stmt)
         elif is_assign and isinstance(stmt.right, C.FunctionCall) and \
                 "_load" in stmt.right.func.name and invariant(stmt):
             hoisted.append(stmt)
         elif is_assign and declares(stmt) and invariant(stmt):
             hoisted.append(stmt)
         else:
             kept.append(stmt)
             if is_assign and declares(stmt):
                 kept_decls.add(stmt.left.name)
     node.body = kept
     return hoisted + [node] + sunk
Example #10
0
    def visit_BinaryOp(self, node):
        """Lower array reads/writes inside the vectorized loop to SIMD calls.

        * Array read not mentioning ``self.loop_var``: ``broadcast_ss(&ref)``.
        * Array read mentioning ``loop_var``: if the innermost subscript
          does not mention it, the buffer is transposed (see ``transpose``).
          Then ``load_ps`` of the element's address is returned.
        * ``x = <load_ps/broadcast_ss call>`` where ``x`` is a ``SymbolRef``
          with a type: ``x`` is retyped to ``get_simd_type()()`` and recorded
          in ``self.symbol_table``.
        * Assignment to an array element mentioning ``loop_var``: the target
          is transposed if needed.  The assignment becomes ``store_epi32``
          when ``self.get_type`` reports a ``ctypes.c_int`` element, and
          ``store_ps`` otherwise.
        * Anything else is visited recursively.

        Raises:
            NotImplementedError: if one buffer would be transposed at two
                different subscript depths.
        """

        def transpose(ref):
            """Move the ``loop_var`` subscript of ``ref`` to the innermost slot.

            The base buffer is renamed ``<name>_transposed``.  The depth the
            subscript came from is recorded in ``self.transposed_buffers``.
            Returns the rewritten reference.
            """
            curr_node = ref
            idx = 1
            # Skip subscripts that are not the bare loop variable; the
            # isinstance guard avoids AttributeError on constant subscripts.
            while not (isinstance(curr_node.left.right, C.SymbolRef) and
                       curr_node.left.right.name == self.loop_var):
                curr_node = curr_node.left
                idx += 1
            curr_node.left = curr_node.left.left
            new_ref = C.ArrayRef(ref, C.SymbolRef(self.loop_var))
            while not isinstance(curr_node, C.SymbolRef):
                curr_node = curr_node.left
            if curr_node.name in self.transposed_buffers and \
                    self.transposed_buffers[curr_node.name] != idx:
                raise NotImplementedError()
            self.transposed_buffers[curr_node.name] = idx
            curr_node.name += "_transposed"
            return new_ref

        def address_of(ref):
            # &ref[0] is just ref's base; the constant lives in ``ref.right``.
            if isinstance(ref.right, C.Constant) and ref.right.value == 0.0:
                return ref.left
            return C.Ref(ref)

        if isinstance(node.op, C.Op.ArrayRef):
            if not util.contains_symbol(node, self.loop_var):
                return broadcast_ss(C.Ref(node))
            if not util.contains_symbol(node.right, self.loop_var):
                node = transpose(node)
            return load_ps(address_of(node))
        elif isinstance(node.op, C.Op.Assign):
            node.right = self.visit(node.right)
            if isinstance(node.right, C.FunctionCall) and \
                    ("load_ps" in node.right.func.name or
                     "broadcast_ss" in node.right.func.name) and \
                    isinstance(node.left, C.SymbolRef) and node.left.type is not None:
                node.left.type = get_simd_type()()
                self.symbol_table[node.left.name] = node.left.type
                return node
            elif isinstance(node.left, C.BinaryOp) and util.contains_symbol(
                    node.left, self.loop_var):
                # Query the element type before ``transpose`` renames the
                # buffer to ``<name>_transposed``.
                elem_type = self.get_type(node.left)
                if not util.contains_symbol(node.left.right, self.loop_var):
                    # Transpose the store target itself (not the whole
                    # Assign), so ``idx`` matches the one the read path
                    # records for the same layout.
                    node.left = transpose(node.left)
                if isinstance(elem_type, ctypes.c_int):
                    store = store_epi32
                else:
                    store = store_ps
                return store(address_of(node.left), node.right)

            node.left = self.visit(node.left)
            return node
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)
        return node
Example #11
0
    def visit_AugAssign(self, node):
        """Unroll ``target op= value`` when ``value`` depends on the unrolled loop.

        If ``node.value`` mentions ``self.target_var`` or a name in
        ``self.unrolled_vars`` and the target is a plain symbol, then:

        * the symbol (via ``self._get_name``) is added to
          ``self.unrolled_vars``;
        * ``self.factor`` deep copies are returned.  In copy ``i`` each
          unrolled variable ``v`` becomes ``v_i``, and ``target_var`` becomes
          ``target_var + i`` (``unroll_type == 0``) or
          ``factor * target_var + i`` (``unroll_type == 1``).

        Otherwise ``node`` is returned unchanged.

        Raises:
            NotImplementedError: if the value depends on the unrolled loop
                but the target is an array element or another non-symbol
                expression, or if ``self.unroll_type`` is neither 0 nor 1.
        """
        check = [
            util.contains_symbol(node.value, var)
            for var in list(self.unrolled_vars) + [self.target_var]
        ]
        if not any(check):
            return node

        if isinstance(node.target, C.SymbolRef):
            self.unrolled_vars.add(self._get_name(node.target.name))
            body = []
            for i in range(self.factor):
                stmt = deepcopy(node)
                for var in self.unrolled_vars:
                    stmt = util.replace_symbol(
                        var, C.SymbolRef(var + "_" + str(i)), stmt)
                if self.unroll_type == 0:
                    index = C.Add(C.SymbolRef(self.target_var),
                                  C.Constant(i))
                elif self.unroll_type == 1:
                    index = C.Add(
                        C.Mul(C.Constant(self.factor),
                              C.SymbolRef(self.target_var)),
                        C.Constant(i))
                else:
                    # Explicit raise: ``false`` is undefined in Python and an
                    # assert would be stripped entirely under ``python -O``.
                    raise NotImplementedError(
                        "unsupported unroll_type: {}".format(
                            self.unroll_type))
                body.append(util.replace_symbol(self.target_var, index, stmt))
            return body

        if isinstance(node.target, C.BinaryOp) and isinstance(
                node.target.op, C.Op.ArrayRef):
            # Unsupported path: an explicit raise, rather than an assert
            # that ``python -O`` would strip (which let unreachable code run).
            raise NotImplementedError(
                "unrolling an augmented assignment to an array element "
                "is not supported")
        raise NotImplementedError()