def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: Optional[bool] = None):
    """
    Adds the dataflow that computes an einsum expression to ``state``.

    If the parsed expression is not a (batched) matrix multiplication (``is_bmm()``),
    it is lowered to a single mapped tasklet over all subscript characters. When some
    input index does not appear in the output, the tasklet accumulates with a
    sum write-conflict resolution. Otherwise the expression is lowered to a nested
    batched-GEMM SDFG. With ``optimize=True``, the expression is first split into a
    chain of smaller contractions following the path computed by ``opt_einsum``.

    :param sdfg: The SDFG owning the arrays; temporary outputs are added to it.
    :param state: The state in which the computation is created.
    :param einsum_string: The einsum subscript expression (e.g., ``'ij,jk->ik'``).
    :param arrays: Names of the input data containers, one per input subscript.
    :param dtype: Data type used for a newly-created output. Defaults to the type
                  of ``output`` if given, otherwise to the type of the first input.
    :param optimize: If True, contract along the ``opt_einsum``-optimized path.
                     Requires the ``opt_einsum`` package and non-symbolic sizes.
    :param output: Name of an existing output container. If None, a temporary
                   transient is created.
    :param nodes: Optional mapping from input array names to existing access nodes
                  in ``state`` to read from. Read nodes are created if not given.
    :param init_output: Whether to zero the output before accumulating into it.
                        If None, the output is zeroed only when the expression sums
                        over indices (write conflicts). An explicit True/False is
                        honored for a user-provided ``output``. The GEMM lowering
                        does not use this flag.
    :return: A tuple of (output data name, output access node).
    :raises ValueError: On operand-count, dimensionality, or dimension-size
                        mismatches, or on symbolic sizes when ``optimize`` is set.
    :raises ImportError: If ``optimize`` is set and ``opt_einsum`` is unavailable.
    """
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum (ModuleNotFoundError is a subclass of ImportError)
        try:
            import opt_einsum as oe
        except (ImportError, NameError) as err:
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.') from err

        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        operands = list(arrays)
        result_node = None
        contractions = path_info.contraction_list

        # Follow path and create a chain of operation SDFG states
        for step, (pair, _, expr, _, _) in enumerate(contractions):
            # opt_einsum uses the final output subscripts for the last contraction,
            # so only that step writes into the caller-provided output (if any)
            is_last = step == len(contractions) - 1
            # A contraction may involve one operand (e.g., a trace), so forward
            # every operand index in ``pair`` instead of assuming exactly two
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          *(operands[i] for i in pair),
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=output if is_last else None,
                                                          nodes=input_nodes,
                                                          init_output=init_output if is_last else None)
            operands = [a for i, a in enumerate(operands) if i not in pair] + [result]
            input_nodes[result] = result_node

        return operands[0], result_node
    # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = [chardict[k] for k in einsum.output] or [1]
    output_index = ','.join(einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, _ = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        # An explicit init_output (including False) is honored for a given output
        to_init = True if init_output is None else init_output

    # If some input index does not appear in the output, several map iterations
    # write the same output element and the result must be accumulated
    is_conflicted = not all(all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet('einsum_reset', {k: '0:%s' % chardict[k]
                                                               for k in einsum.output}, {},
                                              'out_%s = 0' % output,
                                              {'out_%s' % output: Memlet.simple(output, output_index)},
                                              external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(), {'out_%s' % output}, 'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None, Memlet.simple(output, '0'))

        wcr = 'lambda a,b: a+b' if is_conflicted else None

        # One connector per operand position: the same array may be passed more
        # than once with different subscripts (e.g., an outer product of A with A)
        in_conns = ['inp%d_%s' % (i, arr) for i, arr in enumerate(arrays)]

        # Pure einsum map
        state.add_mapped_tasklet('einsum', {k: '0:%s' % v
                                            for k, v in chardict.items()},
                                 {
                                     conn: Memlet.simple(arr, ','.join(inp))
                                     for conn, inp, arr in zip(in_conns, einsum.inputs, arrays)
                                 },
                                 'out_%s = %s' % (output, ' * '.join(in_conns)),
                                 {'out_%s' % output: Memlet.simple(output, output_index, wcr_str=wcr)},
                                 input_nodes=input_nodes,
                                 output_nodes={output: c},
                                 external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides.
        # NOTE(review): strides are derived from the shapes only (products of the
        # trailing dimensions), i.e., contiguous row-major layout is assumed and the
        # descriptors' actual strides are not consulted -- confirm this is intended.
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] + 1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] + 1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] + 1:]) if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] + 1:]) if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] + 1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']
        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']
        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'}, strides)
        state.add_edge(a, None, nsdfg_node, 'X', Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y', Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None, Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
def apply(self, state: SDFGState, sdfg: SDFG):
    """
    Replaces the copy edge from access node ``self.a`` to ``self.b`` with an
    explicit copy map containing a single ``__out = __inp`` tasklet.

    The map iterates over the copied subset on the side whose array has the
    larger (or equal) dimensionality. The indices for the other side are computed
    by ``self.delinearize_linearize``. If either array resides in GPU global
    memory, the map is scheduled as a GPU kernel, or sequentially if the copy is
    already inside GPU device code.

    :param state: The state containing the copy edge.
    :param sdfg: The SDFG containing ``state``.
    """
    adesc = self.a.desc(sdfg)
    bdesc = self.b.desc(sdfg)
    edge = state.edges_between(self.a, self.b)[0]
    src_subset = edge.data.get_src_subset(edge, state)
    dst_subset = edge.data.get_dst_subset(edge, state)

    # Iterate over the subset of the higher-dimensional side
    copy_a = len(adesc.shape) >= len(bdesc.shape)
    copy_shape = src_subset.size() if copy_a else dst_subset.size()

    maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)}

    av = self.a.data
    bv = self.b.data
    avnode = self.a
    bvnode = self.b

    def _absolute_indices(subset):
        # The map variables are zero-based, so shift them by the subset's start
        # (and scale by its step); otherwise copies of offset or strided subsets
        # would access the wrong elements.
        indices = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))]
        if not isinstance(subset, subsets.Range):
            return indices
        return [begin + ind * step for ind, (begin, _, step) in zip(indices, subset.ranges)]

    # Linearize and delinearize to get index expression for other side
    if copy_a:
        a_index = _absolute_indices(src_subset)
        b_index = self.delinearize_linearize(bdesc, copy_shape, dst_subset)
    else:
        a_index = self.delinearize_linearize(adesc, copy_shape, src_subset)
        b_index = _absolute_indices(dst_subset)

    a_subset = subsets.Range([(ind, ind, 1) for ind in a_index])
    b_subset = subsets.Range([(ind, ind, 1) for ind in b_index])

    # Set schedule based on GPU arrays
    schedule = dtypes.ScheduleType.Default
    if adesc.storage == dtypes.StorageType.GPU_Global or bdesc.storage == dtypes.StorageType.GPU_Global:
        # If already inside GPU kernel
        if is_devicelevel_gpu(sdfg, state, self.a):
            schedule = dtypes.ScheduleType.Sequential
        else:
            schedule = dtypes.ScheduleType.GPU_Device

    # Add copy map
    t, _, _ = state.add_mapped_tasklet('copy',
                                       maprange,
                                       dict(__inp=Memlet(data=av, subset=a_subset)),
                                       '__out = __inp',
                                       dict(__out=Memlet(data=bv, subset=b_subset)),
                                       schedule,
                                       external_edges=True,
                                       input_nodes={av: avnode},
                                       output_nodes={bv: bvnode})

    # Set connector types (due to this transformation appearing in codegen, after connector
    # types have been resolved)
    t.in_connectors['__inp'] = adesc.dtype
    t.out_connectors['__out'] = bdesc.dtype

    # Remove old edge
    state.remove_edge(edge)