def make_sdfg(node, parent_state, parent_sdfg):
    """Expand a (batched) matrix-multiplication node into a pure nested SDFG.

    The generated SDFG has two states. The first zeroes every element of the
    output ``_c``. The second runs a map over (batch, rows of A, columns of B,
    inner dimension). Its tasklet multiplies one element of ``_a`` by one
    element of ``_b`` and accumulates the product into ``_c`` through a
    ``lambda x, y: x + y`` write-conflict resolution. When
    ``_get_batchmm_opts`` returns a falsy value (no batch detected), the batch
    map dimension is omitted and a plain 2-D product is generated.

    :param node: The library node being expanded.
    :param parent_state: The state containing ``node``.
    :param parent_sdfg: The SDFG containing ``parent_state``.
    :return: The replacement SDFG, with data containers ``_a``, ``_b``, ``_c``.
    :raises SyntaxError: If the inner dimensions of A and B differ.
    :raises ValueError: If A and B do not have the same storage type.
    """
    # Get metadata from parent SDFG
    ((_, outer_array_a, shape_a, strides_a), (_, outer_array_b, shape_b, strides_b),
     cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)
    outedge = parent_state.out_edges(node)[0]
    cdesc = parent_sdfg.arrays[outedge.data.data]
    bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, cdesc.shape, cdesc.strides)

    if shape_a[-1] != shape_b[-2]:
        raise SyntaxError('Matrix sizes must match')
    if bopt:
        shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
    else:
        shape_c = (shape_a[-2], shape_b[-1])

    dtype_a = outer_array_a.dtype.type
    dtype_b = outer_array_b.dtype.type
    dtype_c = cdesc.dtype.type

    if outer_array_a.storage != outer_array_b.storage:
        raise ValueError("Input matrices must have same storage")
    storage = outer_array_a.storage

    # Create replacement SDFG
    sdfg = dace.SDFG(node.label + "_sdfg")

    _, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=storage)
    _, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=storage)
    # NOTE(review): cdata[-1] is presumably the output strides, mirroring the
    # (edge, array, shape, strides) layout of the A/B operand tuples -- confirm.
    sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=storage)

    # Initialization state: zero the output, because the multiplication state
    # below only accumulates into it.
    init_state = sdfg.add_state()
    init_state.add_mapped_tasklet(
        'batched_matmul_init',
        {'_o%d' % i: '0:%s' % symstr(d) for i, d in enumerate(shape_c)},
        {},
        'out = 0',
        {'out': dace.Memlet.simple('_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))},
        external_edges=True)

    state = sdfg.add_state_after(init_state, node.label + "_state")

    # Map parameters: __i0 = batch index (only when batched), __i1 = row of A,
    # __i2 = column of B, __i3 = inner (reduced) dimension.
    loop_ranges = [('__i1', array_a.shape[-2]), ('__i2', array_b.shape[-1]), ('__i3', array_a.shape[-1])]
    if bopt:
        loop_ranges.insert(0, ('__i0', bopt['b']))
    map_ranges = {param: '0:%s' % symstr(size) for param, size in loop_ranges}
    out_subset = '__i0, __i1, __i2' if bopt else '__i1, __i2'

    # A 2-D operand is indexed without the batch parameter, so the same matrix
    # is reused for every batch entry. NOTE(review): when no batch is detected
    # the operands are assumed to be 2-D -- confirm with _get_batchmm_opts.
    state.add_mapped_tasklet(
        '_BatchedBatchedMatMult_',
        map_ranges,
        {
            '__a': dace.Memlet.simple("_a", ('__i1, __i3' if len(array_a.shape) == 2 else '__i0, __i1, __i3')),
            '__b': dace.Memlet.simple("_b", ('__i3, __i2' if len(array_b.shape) == 2 else '__i0, __i3, __i2'))
        },
        '__c = __a * __b',
        {'__c': dace.Memlet.simple("_c", out_subset, wcr_str='lambda x, y: x + y')},
        external_edges=True)

    return sdfg
def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):
    """Add a MatMul library node computing ``op1 @ op2`` into a new transient.

    Supported operand combinations:

    * matrix x matrix: both operands of rank 2 or 3. If ``_get_batchmm_opts``
      reports batch options, the output gets a leading batch dimension.
    * 2-D matrix x 1-D vector: produces a 1-D result.
    * 1-D vector x 1-D vector: produces a one-element array of shape ``(1,)``.

    :param visitor: Not used by this function.
    :param sdfg: The SDFG holding the operand arrays; the result transient is
                 added to it.
    :param state: The state to which the access nodes and MatMul node are added.
    :param op1: Name of the left operand array in ``sdfg``.
    :param op2: Name of the right operand array in ``sdfg``.
    :return: Name of the newly created transient holding the result.
    :raises SyntaxError: On unsupported operand ranks or mismatched dimensions.
    """
    # Imported here to avoid an import loop
    from dace.libraries.blas.nodes.matmul import MatMul, _get_batchmm_opts

    arr1 = sdfg.arrays[op1]
    arr2 = sdfg.arrays[op2]

    if len(arr1.shape) > 1 and len(arr2.shape) > 1:  # matrix * matrix
        if len(arr1.shape) > 3 or len(arr2.shape) > 3:
            raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 '
                              'not supported')
        if arr1.shape[-1] != arr2.shape[-2]:
            raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2]))

        # Determine batched multiplication
        bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides, None, None)
        if bopt:
            output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
        else:
            output_shape = (arr1.shape[-2], arr2.shape[-1])

    elif len(arr1.shape) == 2 and len(arr2.shape) == 1:  # matrix * vector
        if arr1.shape[1] != arr2.shape[0]:
            # Trailing space keeps the two concatenated literals from fusing
            raise SyntaxError("Number of matrix columns {} must match "
                              "size of vector {}.".format(arr1.shape[1], arr2.shape[0]))
        output_shape = (arr1.shape[0], )

    elif len(arr1.shape) == 1 and len(arr2.shape) == 1:  # vector * vector
        if arr1.shape[0] != arr2.shape[0]:
            raise SyntaxError("Vectors in vector product must have same size: "
                              "{} vs. {}".format(arr1.shape[0], arr2.shape[0]))
        output_shape = (1, )

    else:  # Unsupported combination of ranks, bail
        raise SyntaxError("Cannot multiply arrays with shapes: {} and {}".format(arr1.shape, arr2.shape))

    type1 = arr1.dtype.type
    type2 = arr2.dtype.type
    # Result type follows NumPy's type-promotion rules for the two operands
    restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    # The result transient uses the storage of the left operand
    op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

    acc1 = state.add_read(op1)
    acc2 = state.add_read(op2)
    acc3 = state.add_write(op3)

    tasklet = MatMul('_MatMult_', restype)
    state.add_node(tasklet)
    state.add_edge(acc1, None, tasklet, '_a', dace.Memlet.from_array(op1, arr1))
    state.add_edge(acc2, None, tasklet, '_b', dace.Memlet.from_array(op2, arr2))
    state.add_edge(tasklet, '_c', acc3, None, dace.Memlet.from_array(op3, arr3))

    return op3