def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
    """Rewrite the matched map to work on per-process 2D blocks.

    Data entering the map is routed through communication library nodes:
    scalars go through ``Bcast`` and arrays go through
    ``BlockCyclicScatter`` into a new transient of shape
    ``[shape[0] // Px, shape[1] // Py]``. Arrays leaving the map are written
    to a transient of the same block shape and passed through
    ``BlockCyclicGather`` into the original access node. Memlets inside the
    map scope that referred to a distributed array are renamed to the
    corresponding transient. Finally the map range is shrunk: a
    two-parameter map gets each extent divided by ``Px`` / ``Py``; any other
    map is replaced by a single ``__iflat`` parameter ranging over the total
    size divided by ``commsize``.

    All inputs and outputs are validated before the graph is modified, so
    an unsupported pattern is rejected without leaving a half-transformed
    state behind. Inputs and outputs are processed in edge order, which
    keeps the generated temporary names stable from run to run.

    :param graph: State that contains the matched map.
    :param sdfg: SDFG in which the new transients are registered.
    :raises NotImplementedError: If a map input/output is not an access
        node, has an unsupported descriptor type, or is not accessed in
        full by its memlet.
    """
    # Imported locally, as in the original implementation.
    from dace.data import _prod
    from dace.libraries.mpi import Bcast
    from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

    map_entry = self.map_entry
    map_exit = graph.exit_node(map_entry)

    sz = dace.symbol('commsize', dtype=dace.int32, integer=True, positive=True)
    Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
    Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

    # ---- Validation (no graph mutation happens before this is done) ----
    # Dicts are used as insertion-ordered sets of access nodes.
    inputs = {}
    for src, _, _, _, memlet in graph.in_edges(map_entry):
        if not isinstance(src, nodes.AccessNode):
            raise NotImplementedError
        desc = src.desc(sdfg)
        if not isinstance(desc, (data.Scalar, data.Array)):
            raise NotImplementedError
        shape = list(desc.shape)
        extent = memlet.src_subset.size_exact()
        # The string comparison is a second attempt for the case where the
        # symbolic expressions do not compare equal.
        # TODO: We need a solution for symbols not matching
        if shape != extent and str(shape) != str(extent):
            raise NotImplementedError
        inputs[src] = None

    outputs = {}
    for _, _, dst, _, memlet in graph.out_edges(map_exit):
        if not isinstance(dst, nodes.AccessNode):
            raise NotImplementedError
        desc = dst.desc(sdfg)
        if not isinstance(desc, data.Array):
            raise NotImplementedError
        try:
            extent = memlet.dst_subset.size_exact()
        except AttributeError:
            # Fall back to the plain subset when dst_subset is unavailable.
            extent = memlet.subset.size_exact()
        shape = list(desc.shape)
        # TODO: We need a solution for symbols not matching
        if shape != extent and str(shape) != str(extent):
            raise NotImplementedError
        outputs[dst] = None

    # ---- New map parameters and ranges ----
    # NOTE: Maps with step in their ranges are currently not supported
    if len(map_entry.map.params) == 2:
        params = map_entry.map.params
        begin_x, end_x, _ = map_entry.map.range[0]
        begin_y, end_y, _ = map_entry.map.range[1]
        ranges = [
            (0, (end_x - begin_x + 1) / Px - 1, 1),
            (0, (end_y - begin_y + 1) / Py - 1, 1),
        ]
    else:
        params = ['__iflat']
        total_size = _prod(map_entry.map.range.size_exact())
        ranges = [(0, (total_size) / sz - 1, 1)]

    # ---- Root scalar, set to 0 by a tasklet (consumed by Bcast) ----
    root_name = sdfg.temp_data_name()
    sdfg.add_scalar(root_name, dace.int32, transient=True)
    root_node = graph.add_access(root_name)
    root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'}, '__out = 0')
    graph.add_edge(root_tasklet, '__out', root_node, None, dace.Memlet.simple(root_name, '0'))

    def _add_local_block(desc):
        """Create the local block transient and its block-sizes array."""
        block_x = (desc.shape[0]) // Px
        block_y = (desc.shape[1]) // Py
        local_name, local_arr = sdfg.add_temp_transient([block_x, block_y],
                                                        dtype=desc.dtype,
                                                        storage=desc.storage)
        local_access = graph.add_access(local_name)
        bsizes_name, bsizes_arr = sdfg.add_temp_transient((2, ), dtype=dace.int32)
        bsizes_access = graph.add_access(bsizes_name)
        bsizes_tasklet = nodes.Tasklet('_set_bsizes_', {}, {'__out'},
                                       "__out[0] = {x}; __out[1] = {y}".format(x=block_x, y=block_y))
        graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                       dace.Memlet.from_array(bsizes_name, bsizes_arr))
        return local_name, local_arr, local_access, bsizes_name, bsizes_arr, bsizes_access

    # ---- Inputs: broadcast scalars, scatter arrays ----
    for inp in inputs:
        desc = inp.desc(sdfg)

        if isinstance(desc, data.Scalar):
            local_access = graph.add_access(inp.data)
            bcast_node = Bcast('_Bcast_')
            graph.add_edge(inp, None, bcast_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, bcast_node, '_root', dace.Memlet.simple(root_name, '0'))
            graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(inp.data, desc))
            # Feed the map from the broadcast result instead of the source.
            for edge in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, edge.dst_conn,
                               dace.Memlet.from_array(inp.data, desc))
                graph.remove_edge(edge)

        elif isinstance(desc, data.Array):
            (local_name, local_arr, local_access, bsizes_name, bsizes_arr,
             bsizes_access) = _add_local_block(desc)

            gdesc_name, gdesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
            gdesc_access = graph.add_access(gdesc_name)
            ldesc_name, ldesc_arr = sdfg.add_temp_transient((9, ), dtype=dace.int32)
            ldesc_access = graph.add_access(ldesc_name)

            scatter_node = BlockCyclicScatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(bsizes_access, None, scatter_node, '_block_sizes',
                           dace.Memlet.from_array(bsizes_name, bsizes_arr))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            graph.add_edge(scatter_node, '_gdescriptor', gdesc_access, None,
                           dace.Memlet.from_array(gdesc_name, gdesc_arr))
            graph.add_edge(scatter_node, '_ldescriptor', ldesc_access, None,
                           dace.Memlet.from_array(ldesc_name, ldesc_arr))

            # Feed the map from the local block instead of the full array.
            for edge in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, edge.dst_conn,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(edge)
            # Memlets leaving the map entry keep their subset; only the
            # container they refer to changes.
            for edge in graph.out_edges(map_entry):
                if edge.data.data == inp.data:
                    edge.data.data = local_name

        else:
            raise NotImplementedError

    # ---- Outputs: gather local blocks into the original arrays ----
    for out in outputs:
        desc = out.desc(sdfg)
        if isinstance(desc, data.Scalar) or not isinstance(desc, data.Array):
            raise NotImplementedError

        (local_name, local_arr, local_access, bsizes_name, bsizes_arr,
         bsizes_access) = _add_local_block(desc)

        gather_node = BlockCyclicGather('_Gather_')
        graph.add_edge(local_access, None, gather_node, '_inbuffer',
                       dace.Memlet.from_array(local_name, local_arr))
        graph.add_edge(bsizes_access, None, gather_node, '_block_sizes',
                       dace.Memlet.from_array(bsizes_name, bsizes_arr))
        graph.add_edge(gather_node, '_outbuffer', out, None, dace.Memlet.from_array(out.data, desc))

        # The map now writes into the local block.
        for edge in graph.edges_between(map_exit, out):
            graph.add_edge(map_exit, edge.src_conn, local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            graph.remove_edge(edge)
        for edge in graph.in_edges(map_exit):
            if edge.data.data == out.data:
                edge.data.data = local_name

    map_entry.map.params = params
    map_entry.map.range = subsets.Range(ranges)
def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
    """Rewrite the matched map to work on per-process 1D chunks.

    Data entering the map is routed through MPI library nodes: scalars go
    through ``Bcast`` and arrays go through ``Scatter`` into a new
    one-dimensional transient of ``floor(total_size / commsize)`` elements.
    Arrays leaving the map are written to a transient of that size and
    passed through ``Gather`` into the original access node. Memlets inside
    the map scope that referred to a distributed array are replaced by a
    memlet on the transient indexed with the first map parameter. Finally
    the map range is shrunk: a one-parameter map gets its extent divided by
    ``commsize``; any other map is replaced by a single ``__iflat``
    parameter ranging over the total size divided by ``commsize``.

    All inputs and outputs are validated before the graph is modified, so
    an unsupported pattern is rejected without leaving a half-transformed
    state behind. Inputs and outputs are processed in edge order, which
    keeps the generated temporary names stable from run to run.

    :param graph: State that contains the matched map.
    :param sdfg: SDFG in which the new transients are registered.
    :raises NotImplementedError: If a map input/output is not an access
        node, has an unsupported descriptor type, or is not accessed in
        full by its memlet.
    """
    # Imported locally, as in the original implementation.
    from dace.libraries.mpi import Bcast, Scatter, Gather

    map_entry = self.map_entry
    map_exit = graph.exit_node(map_entry)

    sz = dace.symbol('commsize', dtype=dace.int32)

    # ---- Validation (no graph mutation happens before this is done) ----
    # Dicts are used as insertion-ordered sets of access nodes.
    inputs = {}
    for src, _, _, _, memlet in graph.in_edges(map_entry):
        if not isinstance(src, nodes.AccessNode):
            raise NotImplementedError
        desc = src.desc(sdfg)
        if not isinstance(desc, (data.Scalar, data.Array)):
            raise NotImplementedError
        shape = list(desc.shape)
        extent = memlet.src_subset.size_exact()
        # The string comparison is a second attempt for the case where the
        # symbolic expressions do not compare equal.
        # TODO: We need a solution for symbols not matching
        if shape != extent and str(shape) != str(extent):
            raise NotImplementedError
        inputs[src] = None

    outputs = {}
    for _, _, dst, _, memlet in graph.out_edges(map_exit):
        if not isinstance(dst, nodes.AccessNode):
            raise NotImplementedError
        desc = dst.desc(sdfg)
        if not isinstance(desc, data.Array):
            raise NotImplementedError
        try:
            extent = memlet.dst_subset.size_exact()
        except AttributeError:
            # Fall back to the plain subset when dst_subset is unavailable.
            extent = memlet.subset.size_exact()
        shape = list(desc.shape)
        # TODO: We need a solution for symbols not matching
        if shape != extent and str(shape) != str(extent):
            raise NotImplementedError
        outputs[dst] = None

    # ---- New map parameters and ranges ----
    # NOTE: Maps with step in their ranges are currently not supported
    if len(map_entry.map.params) == 1:
        params = map_entry.map.params
        ranges = [(0, (end - begin + 1) / sz - 1, 1) for begin, end, _ in map_entry.map.range]
    else:
        params = ['__iflat']
        # Product of all dimension sizes of the original map range.
        total_size = reduce(lambda a, b: a * b, map_entry.map.range.size_exact(), 1)
        ranges = [(0, (total_size) / sz - 1, 1)]

    # ---- Root scalar, set to 0 by a tasklet (consumed by the MPI nodes) ----
    root_name = sdfg.temp_data_name()
    sdfg.add_scalar(root_name, dace.int32, transient=True)
    root_node = graph.add_access(root_name)
    root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'}, '__out = 0')
    graph.add_edge(root_tasklet, '__out', root_node, None, dace.Memlet.simple(root_name, '0'))

    # ---- Inputs: broadcast scalars, scatter arrays ----
    for inp in inputs:
        desc = inp.desc(sdfg)

        if isinstance(desc, data.Scalar):
            local_access = graph.add_access(inp.data)
            bcast_node = Bcast('_Bcast_')
            graph.add_edge(inp, None, bcast_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, bcast_node, '_root', dace.Memlet.simple(root_name, '0'))
            graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(inp.data, desc))
            # Feed the map from the broadcast result instead of the source.
            for edge in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, edge.dst_conn,
                               dace.Memlet.from_array(inp.data, desc))
                graph.remove_edge(edge)

        elif isinstance(desc, data.Array):
            local_name, local_arr = sdfg.add_temp_transient([sympy.floor(desc.total_size / sz)],
                                                            dtype=desc.dtype,
                                                            storage=desc.storage)
            local_access = graph.add_access(local_name)
            scatter_node = Scatter('_Scatter_')
            graph.add_edge(inp, None, scatter_node, '_inbuffer', dace.Memlet.from_array(inp.data, desc))
            graph.add_edge(root_node, None, scatter_node, '_root', dace.Memlet.simple(root_name, '0'))
            graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            # Feed the map from the local chunk instead of the full array.
            for edge in graph.edges_between(inp, map_entry):
                graph.add_edge(local_access, None, map_entry, edge.dst_conn,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.remove_edge(edge)
            # Inner reads of the array become reads of the local chunk,
            # indexed by the first map parameter.
            for edge in graph.out_edges(map_entry):
                if edge.data.data == inp.data:
                    edge.data = dace.Memlet.simple(local_name, params[0])

        else:
            raise NotImplementedError

    # ---- Outputs: gather local chunks into the original arrays ----
    for out in outputs:
        desc = out.desc(sdfg)
        if isinstance(desc, data.Scalar) or not isinstance(desc, data.Array):
            raise NotImplementedError

        local_name, local_arr = sdfg.add_temp_transient([sympy.floor(desc.total_size / sz)],
                                                        dtype=desc.dtype,
                                                        storage=desc.storage)
        local_access = graph.add_access(local_name)
        gather_node = Gather('_Gather_')
        graph.add_edge(local_access, None, gather_node, '_inbuffer',
                       dace.Memlet.from_array(local_name, local_arr))
        graph.add_edge(root_node, None, gather_node, '_root', dace.Memlet.simple(root_name, '0'))
        graph.add_edge(gather_node, '_outbuffer', out, None, dace.Memlet.from_array(out.data, desc))

        # The map now writes into the local chunk.
        for edge in graph.edges_between(map_exit, out):
            graph.add_edge(map_exit, edge.src_conn, local_access, None,
                           dace.Memlet.from_array(local_name, local_arr))
            graph.remove_edge(edge)
        # Inner writes to the array become writes to the local chunk,
        # indexed by the first map parameter.
        for edge in graph.in_edges(map_exit):
            if edge.data.data == out.data:
                edge.data = dace.Memlet.simple(local_name, params[0])

    map_entry.map.params = params
    map_entry.map.range = subsets.Range(ranges)