Exemplo n.º 1
0
    def copy_memory(self, sdfg: SDFG, dfg: SDFGState, state_id: int, src_node: nodes.Node, dst_node: nodes.Node,
                    edge: gr.MultiConnectorEdge[mm.Memlet], function_stream: CodeIOStream,
                    callsite_stream: CodeIOStream) -> None:
        """Emit a write-conflict-resolved (WCR) copy along ``edge`` using SVE.

        The source vector is first reduced to a scalar with the SVE intrinsic
        mapped from the WCR's reduction type (under an all-true ``svptrue_*``
        predicate). That scalar is then written through the WCR expression
        produced by the CPU code generator.

        :param sdfg: The SDFG being generated.
        :param dfg: The state containing ``edge``; used to check whether the write is conflicted.
        :param state_id: Not used by this implementation.
        :param src_node: Source node; its label is emitted as the vector operand of the reduction.
        :param dst_node: Not used by this implementation.
        :param edge: Edge whose memlet carries the WCR function.
        :param function_stream: Not used by this implementation.
        :param callsite_stream: Stream the generated statement is written to.
        :raises util.NotSupportedError: If the reduction type has no SVE equivalent.
        :raises NotImplementedError: For a horizontal non-atomic reduction, i.e. when the
                                     write is not conflicted and the source descriptor's
                                     dtype is a pointer or vector.
        :raises RuntimeError: If the WCR expression lacks the ``@`` value placeholder.
        """
        # Check whether it is a known reduction that is possible in SVE
        reduction_type = detect_reduction_type(edge.data.wcr)
        if reduction_type not in util.REDUCTION_TYPE_TO_SVE:
            raise util.NotSupportedError('Unsupported reduction in SVE')

        nc = not is_write_conflicted(dfg, edge)
        desc = edge.src.desc(sdfg)
        if nc and isinstance(desc.dtype, (dtypes.pointer, dtypes.vector)):
            # Horizontal non-atomic reduction is not supported by this path
            raise NotImplementedError('Horizontal non-atomic SVE reduction is not implemented in copy_memory')

        # WCR on vectors works in two steps:
        # 1. Reduce the SVE register using SVE instructions into a scalar
        # 2. WCR the scalar to memory using DaCe functionality
        wcr = self.cpu_codegen.write_and_resolve_expr(sdfg, edge.data, not nc, None, '@', dtype=desc.dtype)
        # Split the WCR template around its '@' value placeholder. Using str.find
        # here would yield -1 on a missing placeholder and silently mangle the code.
        wcr_prefix, placeholder, wcr_suffix = wcr.partition('@')
        if not placeholder:
            raise RuntimeError(f'WCR expression lacks the "@" placeholder: {wcr}')

        # Close the reduction call *before* the WCR suffix so the reduced scalar
        # replaces exactly the placeholder (same layout as vector_reduction_expr).
        reduction_call = (util.REDUCTION_TYPE_TO_SVE[reduction_type] +
                          f'(svptrue_{util.TYPE_TO_SVE_SUFFIX[desc.dtype]}(), ' + src_node.label + ')')
        callsite_stream.write(wcr_prefix + reduction_call + wcr_suffix + ';')
Exemplo n.º 2
0
    def vector_reduction_expr(self, edge, dtype, rhs):
        """Emit SVE code for a write-conflict-resolved (WCR) write of ``rhs``.

        Two strategies are used:

        * If the edge is write-conflicted, or the source connector type is not a
          pointer/vector, the SVE register is reduced to a scalar with the SVE
          intrinsic mapped from the reduction type. The scalar is then written
          through the WCR expression produced by the CPU code generator. Nothing
          is emitted when the memlet path ends in an access node whose storage is
          ``SVE_Register``.
        * Otherwise a horizontal non-atomic reduction is emitted: the target is
          loaded (``svld1``, or ``svld1_gather_index`` for non-unit stride),
          combined with ``rhs`` using the predicated ``_x`` form of the reduction
          operator, and stored back (``svst1`` / ``svst1_scatter_index``).

        :param edge: Memlet edge carrying the WCR function.
        :param dtype: Element type of the value being reduced.
        :param rhs: Right-hand side passed to ``dispatch_expect`` (presumably an AST node).
        :raises util.NotSupportedError: If the reduction type has no SVE equivalent.
        :raises RuntimeError: If the WCR expression lacks the ``@`` value placeholder.
        """
        # Check whether it is a known reduction that is possible in SVE
        reduction_type = detect_reduction_type(edge.data.wcr)
        if reduction_type not in util.REDUCTION_TYPE_TO_SVE:
            raise util.NotSupportedError('Unsupported reduction in SVE')
        sve_reduction = util.REDUCTION_TYPE_TO_SVE[reduction_type]

        nc = not is_write_conflicted(self.dfg, edge)
        # Connector lookup stays inside the condition so it is short-circuited
        # when the write is conflicted (as in the original logic).
        if not nc or not isinstance(edge.src.out_connectors[edge.src_conn], (dtypes.pointer, dtypes.vector)):
            # WCR on vectors works in two steps:
            # 1. Reduce the SVE register using SVE instructions into a scalar
            # 2. WCR the scalar to memory using DaCe functionality
            dst_node = self.dfg.memlet_path(edge)[-1].dst
            if (isinstance(dst_node, nodes.AccessNode)
                    and dst_node.desc(self.sdfg).storage == dtypes.StorageType.SVE_Register):
                return

            wcr = self.cpu_codegen.write_and_resolve_expr(self.sdfg, edge.data, not nc, None, '@', dtype=dtype)
            # Split the WCR template around its '@' value placeholder. Using str.find
            # here would yield -1 on a missing placeholder and silently mangle the code.
            wcr_prefix, placeholder, wcr_suffix = wcr.partition('@')
            if not placeholder:
                raise RuntimeError(f'WCR expression lacks the "@" placeholder: {wcr}')

            self.fill(wcr_prefix)
            self.write(f'{sve_reduction}({self.pred_name}, ')
            self.dispatch_expect(rhs, dtypes.vector(dtype, -1))
            self.write(')')
            self.write(wcr_suffix + ';')
            return

        ######################
        # Horizontal non-atomic reduction
        stride = edge.data.get_stride(self.sdfg, self.map)
        src_type = edge.src.out_connectors[edge.src_conn]

        # long long fix: cast 64-bit integer pointers explicitly (presumably so the
        # pointer type matches what the SVE load/store intrinsics expect)
        ptr_cast = ''
        if src_type.type == np.int64:
            ptr_cast = '(int64_t*) '
        elif src_type.type == np.uint64:
            ptr_cast = '(uint64_t*) '

        store_args = f'{self.pred_name}, {ptr_cast}{cpp_ptr_expr(self.sdfg, edge.data, DefinedType.Pointer)}'

        # Predicated element-wise ("_x") operator derived from the reduction name
        # by dropping its last character (presumably a trailing 'v', e.g. svaddv -> svadd_x)
        red_type = sve_reduction[:-1] + '_x'

        if stride == 1:
            store_fn, load_fn = 'svst1', 'svld1'
        else:
            # Non-unit stride: address elements via an index vector 0, stride, 2*stride, ...
            index_bits = util.get_base_type(src_type).bytes * 8
            store_args = f'{store_args}, svindex_s{index_bits}(0, {sym2cpp(stride)})'
            store_fn, load_fn = 'svst1_scatter_index', 'svld1_gather_index'

        self.write(f'{store_fn}({store_args}, {red_type}({self.pred_name}, {load_fn}({store_args}), ')
        self.dispatch_expect(rhs, dtypes.vector(dtype, -1))
        self.write('));')