Exemplo n.º 1
0
    def make_sdfg(node, parent_state, parent_sdfg):
        """Create the pure (map-based) expansion SDFG of a (batched) matmul.

        The returned SDFG has two states. The first zero-initializes the
        output ``_c``. The second multiplies elements of ``_a`` and ``_b``
        inside a map and accumulates the products into ``_c`` through a
        ``lambda x, y: x + y`` write-conflict resolution.

        :param node: The library node being expanded.
        :param parent_state: The state in ``parent_sdfg`` that contains
                             ``node``.
        :param parent_sdfg: The SDFG that contains ``parent_state``.
        :return: The replacement SDFG, with data containers ``_a``, ``_b``
                 (inputs) and ``_c`` (output).
        :raises SyntaxError: If the inner dimensions of the operands differ.
        :raises ValueError: If the two inputs use different storage types.
        """
        # Get metadata from parent SDFG
        ((_, outer_array_a, shape_a, strides_a),
         (_, outer_array_b, shape_b, strides_b),
         cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)
        outedge = parent_state.out_edges(node)[0]
        cdesc = parent_sdfg.arrays[outedge.data.data]
        bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b,
                                 cdesc.shape, cdesc.strides)

        if shape_a[-1] != shape_b[-2]:
            raise SyntaxError('Matrix sizes must match')
        if bopt:
            shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
        else:
            shape_c = (shape_a[-2], shape_b[-1])

        dtype_a = outer_array_a.dtype.type
        dtype_b = outer_array_b.dtype.type
        dtype_c = cdesc.dtype.type

        if outer_array_a.storage != outer_array_b.storage:
            raise ValueError("Input matrices must have same storage")
        storage = outer_array_a.storage

        # Create replacement SDFG
        sdfg = dace.SDFG(node.label + "_sdfg")

        _, array_a = sdfg.add_array("_a",
                                    shape_a,
                                    dtype_a,
                                    strides=strides_a,
                                    storage=storage)
        _, array_b = sdfg.add_array("_b",
                                    shape_b,
                                    dtype_b,
                                    strides=strides_b,
                                    storage=storage)
        # NOTE(review): cdata presumably mirrors the (edge, array, shape,
        # strides) layout of the input tuples, so cdata[-1] would be the
        # outer output strides -- confirm against _get_matmul_operands.
        sdfg.add_array("_c",
                       shape_c,
                       dtype_c,
                       strides=cdata[-1],
                       storage=storage)

        # Zero-initialize the output: the multiplication state below only
        # accumulates into it (via WCR).
        init_state = sdfg.add_state()
        init_state.add_mapped_tasklet(
            'batched_matmul_init',
            {'_o%d' % i: '0:%s' % symstr(d)
             for i, d in enumerate(shape_c)}, {},
            'out = 0', {
                'out':
                dace.Memlet.simple(
                    '_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))
            },
            external_edges=True)

        state = sdfg.add_state_after(init_state, node.label + "_state")

        # Map parameters: __i0 = batch index (batched case only),
        # __i1 = output row, __i2 = output column, __i3 = reduction index.
        if bopt:
            map_ranges = {'__i0': '0:%s' % symstr(bopt['b'])}
            out_subset = '__i0, __i1, __i2'
        else:
            # Non-batched: _c is 2-D (see shape_c above), so there is no
            # batch index to map over.
            # NOTE(review): presumably _get_batchmm_opts reports batching
            # whenever an operand is 3-D, so both operands are 2-D here.
            map_ranges = {}
            out_subset = '__i1, __i2'
        map_ranges['__i1'] = '0:%s' % symstr(array_a.shape[-2])
        map_ranges['__i2'] = '0:%s' % symstr(array_b.shape[-1])
        map_ranges['__i3'] = '0:%s' % symstr(array_a.shape[-1])

        # A 2-D operand does not depend on __i0 and is reused by every batch;
        # a 3-D operand is indexed by the batch index.
        a_subset = ('__i1, __i3'
                    if len(array_a.shape) == 2 else '__i0, __i1, __i3')
        b_subset = ('__i3, __i2'
                    if len(array_b.shape) == 2 else '__i0, __i3, __i2')

        state.add_mapped_tasklet(
            '_BatchedBatchedMatMult_',
            map_ranges, {
                '__a': dace.Memlet.simple("_a", a_subset),
                '__b': dace.Memlet.simple("_b", b_subset)
            },
            '__c = __a * __b', {
                '__c':
                dace.Memlet.simple(
                    "_c", out_subset, wcr_str='lambda x, y: x + y')
            },
            external_edges=True)

        return sdfg
Exemplo n.º 2
0
def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):
    """Add a ``MatMul`` library node computing ``op1 @ op2`` to ``state``.

    The result goes to a new transient array in ``sdfg``. Its shape depends
    on the operand ranks:

    * matrix (2-D or 3-D) @ matrix (2-D or 3-D) -> ``(M, N)``, or
      ``(batch, M, N)`` when a batched multiplication is detected;
    * matrix (2-D) @ vector -> ``(M,)``;
    * vector @ vector -> ``(1,)``.

    :param visitor: Not used by this function.
    :param sdfg: The SDFG that owns ``op1``/``op2`` and receives the output
                 transient.
    :param state: The state to which the access nodes and the library node
                  are added.
    :param op1: Name of the left operand data container.
    :param op2: Name of the right operand data container.
    :return: The name of the new transient output container.
    :raises SyntaxError: On unsupported operand ranks or mismatching
                         dimensions.
    """

    from dace.libraries.blas.nodes.matmul import MatMul  # Avoid import loop

    arr1 = sdfg.arrays[op1]
    arr2 = sdfg.arrays[op2]
    ndim1 = len(arr1.shape)
    ndim2 = len(arr2.shape)

    if ndim1 > 1 and ndim2 > 1:  # matrix * matrix

        if ndim1 > 3 or ndim2 > 3:
            raise SyntaxError(
                'Matrix multiplication of tensors of dimensions > 3 '
                'not supported')

        if arr1.shape[-1] != arr2.shape[-2]:
            raise SyntaxError('Matrix dimension mismatch %s != %s' %
                              (arr1.shape[-1], arr2.shape[-2]))

        from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

        # Determine batched multiplication. No output container exists yet,
        # hence no output shape/strides are passed.
        bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape,
                                 arr2.strides, None, None)
        if bopt:
            output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
        else:
            output_shape = (arr1.shape[-2], arr2.shape[-1])

    elif ndim1 == 2 and ndim2 == 1:  # matrix * vector

        if arr1.shape[1] != arr2.shape[0]:
            # Note the trailing space: the two literals are concatenated.
            raise SyntaxError("Number of matrix columns {} must match "
                              "size of vector {}.".format(
                                  arr1.shape[1], arr2.shape[0]))

        output_shape = (arr1.shape[0], )

    elif ndim1 == 1 and ndim2 == 1:  # vector * vector

        if arr1.shape[0] != arr2.shape[0]:
            raise SyntaxError("Vectors in vector product must have same size: "
                              "{} vs. {}".format(arr1.shape[0], arr2.shape[0]))

        output_shape = (1, )

    else:  # Unsupported combination of ranks (e.g. vector * matrix), bail

        raise SyntaxError(
            "Cannot multiply arrays with shapes: {} and {}".format(
                arr1.shape, arr2.shape))

    # Result type follows NumPy type promotion of the two operand types
    type1 = arr1.dtype.type
    type2 = arr2.dtype.type
    restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

    acc1 = state.add_read(op1)
    acc2 = state.add_read(op2)
    acc3 = state.add_write(op3)

    matmul_node = MatMul('_MatMult_', restype)
    state.add_node(matmul_node)
    state.add_edge(acc1, None, matmul_node, '_a',
                   dace.Memlet.from_array(op1, arr1))
    state.add_edge(acc2, None, matmul_node, '_b',
                   dace.Memlet.from_array(op2, arr2))
    state.add_edge(matmul_node, '_c', acc3, None,
                   dace.Memlet.from_array(op3, arr3))

    return op3