示例#1
0
    def apply(self, sdfg: dace.SDFG):
        """Distribute a 2D element-wise map over the ``Px`` x ``Py`` grid.

        Array inputs are split with ``BlockCyclicScatter`` into local tiles of
        shape ``(shape[0] // Px, shape[1] // Py)``, scalar inputs are
        broadcast with ``Bcast`` (root 0), array outputs are collected back
        with ``BlockCyclicGather``, and each map dimension is shrunk to
        ``extent / P`` iterations starting at 0.

        Every precondition is checked before the state is modified, so a
        ``NotImplementedError`` leaves the SDFG unchanged.

        :param sdfg: SDFG whose state ``self.state_id`` contains the map.
        :raises NotImplementedError: If the map does not have exactly two
            parameters; if a map input is not an access node to a scalar or
            a 2D array; if a map output is not an access node to a 2D array;
            or if any such edge does not span the whole container.
        """
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]
        map_exit = graph.exit_node(map_entry)

        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

        def _edge_subset(memlet, attr):
            """Return ``getattr(memlet, attr)``, or ``memlet.subset`` if unset."""
            subset = getattr(memlet, attr, None)
            return memlet.subset if subset is None else subset

        def _spans_whole(desc, subset):
            """Return True if ``subset`` has exactly the shape of ``desc``."""
            size = subset.size_exact()
            if list(desc.shape) == size:
                return True
            # Second attempt: compare textual forms, since same-named symbols
            # may not compare equal.
            # TODO: We need a solution for symbols not matching
            return str(list(desc.shape)) == str(size)

        def _add_local_tile(desc):
            """Add a transient holding the local tile of the 2D array ``desc``."""
            return sdfg.add_temp_transient(
                [desc.shape[0] // Px, desc.shape[1] // Py],
                dtype=desc.dtype,
                storage=desc.storage)

        def _add_block_sizes(desc):
            """Add a 2-element int32 transient set to the tile sizes of ``desc``.

            Returns ``(name, descriptor, access node)`` of the new transient.
            """
            bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                (2, ), dtype=dace.int32)
            bsizes_access = graph.add_access(bsizes_name)
            bsizes_tasklet = graph.add_tasklet(
                '_set_bsizes_', {}, {'__out'},
                "__out[0] = {x}; __out[1] = {y}".format(
                    x=desc.shape[0] // Px, y=desc.shape[1] // Py))
            graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            return bsizes_name, bsizes_arr, bsizes_access

        # ---- Validate everything before touching the graph. ----

        # NOTE: Maps with step in their ranges are currently not supported.
        # Only 2D maps are handled: the inner memlets keep indexing with the
        # original map parameters, so the map cannot be flattened here.
        if len(map_entry.map.params) != 2:
            raise NotImplementedError

        # Insertion-ordered dicts (node -> descriptor) instead of sets, so
        # that transients are created in a deterministic order.
        inputs = {}
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if isinstance(desc, data.Array) and len(desc.shape) != 2:
                raise NotImplementedError
            if not _spans_whole(desc, _edge_subset(m, 'src_subset')):
                raise NotImplementedError
            inputs[src] = desc

        outputs = {}
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array) or len(desc.shape) != 2:
                raise NotImplementedError
            if not _spans_whole(desc, _edge_subset(m, 'dst_subset')):
                raise NotImplementedError
            outputs[dst] = desc

        # Each map dimension now iterates over one local tile, from 0.
        ranges = []
        for dim, procs in enumerate((Px, Py)):
            begin, end, _ = map_entry.map.range[dim]
            ranges.append((0, (end - begin + 1) / procs - 1, 1))

        from dace.libraries.mpi import Bcast
        from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

        # ---- Rewrite the graph. ----

        # Scalar holding the root rank (0) for the broadcasts.
        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'},
                                         '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None,
                       dace.Memlet.simple(root_name, '0'))

        for inp, desc in inputs.items():
            if isinstance(desc, data.Scalar):
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root',
                               dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)
                continue

            # 2D array (guaranteed by the validation above).
            local_name, local_arr = _add_local_tile(desc)
            local_access = graph.add_access(local_name)
            bsizes_name, bsizes_arr, bsizes_access = _add_block_sizes(desc)
            gdesc_name, gdesc_arr = sdfg.add_temp_transient(
                (9, ), dtype=dace.int32)
            gdesc_access = graph.add_access(gdesc_name)
            ldesc_name, ldesc_arr = sdfg.add_temp_transient(
                (9, ), dtype=dace.int32)
            ldesc_access = graph.add_access(ldesc_name)
            scatter_node = BlockCyclicScatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer',
                           dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(scatter_node, '_gdescriptor', gdesc_access, None,
                           dace.Memlet.from_array(gdesc_name, gdesc_arr))
            graph.add_edge(scatter_node, '_ldescriptor', ldesc_access, None,
                           dace.Memlet.from_array(ldesc_name, ldesc_arr))
            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            # Inner memlets keep their subsets but now read the local tile.
            for e in graph.out_edges(map_entry):
                if e.data.data == inp.data:
                    e.data.data = local_name

        for out, desc in outputs.items():
            local_name, local_arr = _add_local_tile(desc)
            local_access = graph.add_access(local_name)
            bsizes_name, bsizes_arr, bsizes_access = _add_block_sizes(desc)
            gather_node = BlockCyclicGather('_Gather_')
            graph.add_edge(local_access, None, gather_node, '_inbuffer',
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(bsizes_access, None, gather_node, '_block_sizes',
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(gather_node, '_outbuffer', out, None,
                           dace.Memlet.from_array(out.data, desc))
            for e in graph.edges_between(map_exit, out):
                graph.add_edge(map_exit, e.src_conn, local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            # Inner memlets keep their subsets but now write the local tile.
            for e in graph.in_edges(map_exit):
                if e.data.data == out.data:
                    e.data.data = local_name

        # Map parameters are unchanged; only the iteration ranges shrink.
        map_entry.map.range = subsets.Range(ranges)
示例#2
0
    def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
        """Distribute an element-wise map over ``commsize`` MPI ranks.

        Array inputs are split with ``Scatter`` into 1D local chunks of
        ``floor(total_size / commsize)`` elements, scalar inputs are broadcast
        with ``Bcast``, and array outputs are collected back with ``Gather``
        (all with root 0). A single-parameter map keeps its parameter; any
        other map is flattened into the single parameter ``__iflat``. Inner
        memlets of distributed arrays are rewritten to index the local chunk
        with that parameter.

        Every precondition is checked before the state is modified, so a
        ``NotImplementedError`` leaves the SDFG unchanged.

        :param graph: State containing ``self.map_entry``.
        :param sdfg: SDFG that owns ``graph``.
        :raises NotImplementedError: If a map input is not an access node to a
            scalar or array, if a map output is not an access node to an
            array, or if any such edge does not span the whole container.
        """
        map_entry = self.map_entry
        map_exit = graph.exit_node(map_entry)

        sz = dace.symbol('commsize', dtype=dace.int32)

        def _edge_subset(memlet, attr):
            """Return ``getattr(memlet, attr)``, or ``memlet.subset`` if unset."""
            subset = getattr(memlet, attr, None)
            return memlet.subset if subset is None else subset

        def _spans_whole(desc, subset):
            """Return True if ``subset`` has exactly the shape of ``desc``."""
            size = subset.size_exact()
            if list(desc.shape) == size:
                return True
            # Second attempt: compare textual forms, since same-named symbols
            # may not compare equal.
            # TODO: We need a solution for symbols not matching
            return str(list(desc.shape)) == str(size)

        # ---- Validate everything before touching the graph. ----

        # Insertion-ordered dicts (node -> descriptor) instead of sets, so
        # that transients are created in a deterministic order.
        inputs = {}
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if not _spans_whole(desc, _edge_subset(m, 'src_subset')):
                raise NotImplementedError
            inputs[src] = desc

        outputs = {}
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array):
                raise NotImplementedError
            if not _spans_whole(desc, _edge_subset(m, 'dst_subset')):
                raise NotImplementedError
            outputs[dst] = desc

        # NOTE: Maps with step in their ranges are currently not supported
        if len(map_entry.map.params) == 1:
            params = map_entry.map.params
            ranges = [(0, (end - begin + 1) / sz - 1, 1)
                      for begin, end, _ in map_entry.map.range]
        else:
            # Flatten the map into one parameter over this rank's share of
            # the total iteration count.
            params = ['__iflat']
            total_size = reduce(lambda a, b: a * b,
                                map_entry.map.range.size_exact(), 1)
            ranges = [(0, total_size / sz - 1, 1)]

        from dace.libraries.mpi import Bcast, Scatter, Gather

        # ---- Rewrite the graph. ----

        # Scalar holding the root rank (0) for Bcast/Scatter/Gather.
        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'},
                                         '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None,
                       dace.Memlet.simple(root_name, '0'))

        for inp, desc in inputs.items():
            if isinstance(desc, data.Scalar):
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root',
                               dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)
                continue

            # Array (guaranteed by the validation above).
            local_name, local_arr = sdfg.add_temp_transient(
                [sympy.floor(desc.total_size / sz)],
                dtype=desc.dtype,
                storage=desc.storage)
            local_access = graph.add_access(local_name)
            scatter_node = Scatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer',
                           dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, scatter_node, '_root',
                           dace.Memlet.simple(root_name, '0'))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            for e in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, e.dst_conn,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            # Inner memlets now read one element of the 1D local chunk.
            for e in graph.out_edges(map_entry):
                if e.data.data == inp.data:
                    e.data = dace.Memlet.simple(local_name, params[0])

        for out, desc in outputs.items():
            local_name, local_arr = sdfg.add_temp_transient(
                [sympy.floor(desc.total_size / sz)],
                dtype=desc.dtype,
                storage=desc.storage)
            local_access = graph.add_access(local_name)
            gather_node = Gather('_Gather_')
            graph.add_edge(local_access, None, gather_node, '_inbuffer',
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(root_node, None, gather_node, '_root',
                           dace.Memlet.simple(root_name, '0'))
            graph.add_edge(gather_node, '_outbuffer', out, None,
                           dace.Memlet.from_array(out.data, desc))
            for e in graph.edges_between(map_exit, out):
                graph.add_edge(map_exit, e.src_conn, local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(e)
            # Inner memlets now write one element of the 1D local chunk.
            for e in graph.in_edges(map_exit):
                if e.data.data == out.data:
                    e.data = dace.Memlet.simple(local_name, params[0])

        map_entry.map.params = params
        map_entry.map.range = subsets.Range(ranges)