Example #1
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):

    from dace.libraries.mpi.nodes.wait import Wait

    libnode = Wait('_Wait_')

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request

    desc = sdfg.arrays[req_name]
    req_node = state.add_access(req_name)

    src = sdfg.add_temp_transient([1], dtypes.int32)
    src_node = state.add_write(src[0])
    tag = sdfg.add_temp_transient([1], dtypes.int32)
    tag_node = state.add_write(tag[0])

    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, desc)

    state.add_edge(req_node, None, libnode, '_request', req_mem)
    state.add_edge(libnode, '_stat_source', src_node, None,
                   Memlet.from_array(*src))
    state.add_edge(libnode, '_stat_tag', tag_node, None,
                   Memlet.from_array(*tag))

    return None
Example #2
def _reduce(sdfg: SDFG,
            state: SDFGState,
            redfunction: Callable[[Any, Any], Any],
            in_array: str,
            out_array=None,
            axis=None,
            identity=None):
    if out_array is None:
        inarr = in_array
        # Convert axes to tuple
        if axis is not None and not isinstance(axis, (tuple, list)):
            axis = (axis, )
        if axis is not None:
            axis = tuple(pystr_to_symbolic(a) for a in axis)
        input_subset = parse_memlet_subset(sdfg.arrays[inarr],
                                           ast.parse(in_array).body[0].value,
                                           {})
        input_memlet = Memlet.simple(inarr, input_subset)
        output_shape = None
        if axis is None:
            output_shape = [1]
        else:
            output_subset = copy.deepcopy(input_subset)
            output_subset.pop(axis)
            output_shape = output_subset.size()
        outarr, arr = sdfg.add_temp_transient(output_shape,
                                              sdfg.arrays[inarr].dtype,
                                              sdfg.arrays[inarr].storage)
        output_memlet = Memlet.from_array(outarr, arr)
    else:
        inarr = in_array
        outarr = out_array

        # Convert axes to tuple
        if axis is not None and not isinstance(axis, (tuple, list)):
            axis = (axis, )
        if axis is not None:
            axis = tuple(pystr_to_symbolic(a) for a in axis)

        # Compute memlets
        input_subset = parse_memlet_subset(sdfg.arrays[inarr],
                                           ast.parse(in_array).body[0].value,
                                           {})
        input_memlet = Memlet.simple(inarr, input_subset)
        output_subset = parse_memlet_subset(sdfg.arrays[outarr],
                                            ast.parse(out_array).body[0].value,
                                            {})
        output_memlet = Memlet.simple(outarr, output_subset)

    # Create reduce subgraph
    inpnode = state.add_read(inarr)
    rednode = state.add_reduce(redfunction, axis, identity)
    outnode = state.add_write(outarr)
    state.add_nedge(inpnode, rednode, input_memlet)
    state.add_nedge(rednode, outnode, output_memlet)

    if out_array is None:
        return outarr
    else:
        return []
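
# Usage sketch (illustrative, not from the source): assuming the surrounding
# module imports and a hand-built SDFG, _reduce lowers a sum over axis 0 into
# a Reduce library node writing to a fresh transient. The array name 'A' is
# hypothetical.
import dace
sdfg = dace.SDFG('reduce_demo')
state = sdfg.add_state()
sdfg.add_array('A', [8, 4], dace.float64)
tmp = _reduce(sdfg, state, 'lambda a, b: a + b', 'A', axis=0)
# 'tmp' names a new [4]-shaped transient holding the per-column sums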
Example #3
def _assignop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
    """ Implements a general element-wise array assignment operator. """
    arr1 = sdfg.arrays[op1]

    name, _ = sdfg.add_temp_transient(arr1.shape, arr1.dtype, arr1.storage)
    write_memlet = None
    if opcode:
        write_memlet = Memlet.simple(
            name,
            ','.join(['__i%d' % i for i in range(len(arr1.shape))]),
            wcr_str='lambda x, y: x %s y' % opcode)
    else:
        write_memlet = Memlet.simple(
            name, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))
    state.add_mapped_tasklet(
        "_%s_" % opname,
        {'__i%d' % i: '0:%s' % s
         for i, s in enumerate(arr1.shape)}, {
             '__in1':
             Memlet.simple(
                 op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))
         },
        '__out = __in1', {'__out': write_memlet},
        external_edges=True)
    return name
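
# Usage sketch (illustrative, not from the source): with a non-empty opcode the
# write memlet carries a write-conflict resolution of 'lambda x, y: x + y', so
# the new transient accumulates rather than overwrites. 'A' and the opname
# 'augassign' are hypothetical.
import dace
sdfg = dace.SDFG('assignop_demo')
state = sdfg.add_state()
sdfg.add_array('A', [4, 4], dace.float64)
tmp = _assignop(sdfg, state, 'A', '+', 'augassign')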
Example #4
def _define_local_ex(sdfg: SDFG,
                     state: SDFGState,
                     shape: Shape,
                     dtype: dace.typeclass,
                     storage: dtypes.StorageType = dtypes.StorageType.Default):
    """ Defines a local array in a DaCe program. """
    name, _ = sdfg.add_temp_transient(shape, dtype, storage=storage)
    return name
Example #5
def eye(sdfg: SDFG, state: SDFGState, N, M=None, k=0, dtype=dace.float64):
    M = M or N
    name, _ = sdfg.add_temp_transient([N, M], dtype)

    state.add_mapped_tasklet('eye',
                             dict(i='0:%s' % N, j='0:%s' % M), {},
                             'val = 1 if i == (j - %s) else 0' % k,
                             dict(val=dace.Memlet.simple(name, 'i, j')),
                             external_edges=True)

    return name
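
# Usage sketch (illustrative, not from the source): a 4x4 float32 "eye" with
# the diagonal shifted by k=1, materialized as a fresh transient via a mapped
# tasklet.
import dace
sdfg = dace.SDFG('eye_demo')
state = sdfg.add_state()
name = eye(sdfg, state, 4, k=1, dtype=dace.float32)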
Example #6
def _transpose(sdfg: SDFG, state: SDFGState, inpname: str):

    arr1 = sdfg.arrays[inpname]
    restype = arr1.dtype
    outname, arr2 = sdfg.add_temp_transient((arr1.shape[1], arr1.shape[0]),
                                            restype, arr1.storage)

    acc1 = state.add_read(inpname)
    acc2 = state.add_write(outname)
    import dace.libraries.blas  # Avoid import loop
    tasklet = dace.libraries.blas.Transpose('_Transpose_', restype)
    state.add_node(tasklet)
    state.add_edge(acc1, None, tasklet, '_inp',
                   dace.Memlet.from_array(inpname, arr1))
    state.add_edge(tasklet, '_out', acc2, None,
                   dace.Memlet.from_array(outname, arr2))

    return outname
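
# Usage sketch (illustrative, not from the source): transposing a 3x5 array
# through the BLAS Transpose library node; 'A' is a hypothetical array name.
import dace
sdfg = dace.SDFG('transpose_demo')
state = sdfg.add_state()
sdfg.add_array('A', [3, 5], dace.float64)
out = _transpose(sdfg, state, 'A')  # 'out' names a new 5x3 transient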
Example #7
def _simple_call(sdfg: SDFG,
                 state: SDFGState,
                 inpname: str,
                 func: str,
                 restype: dace.typeclass = None):
    """ Implements a simple call of the form `out = func(inp)`. """
    inparr = sdfg.arrays[inpname]
    if restype is None:
        restype = sdfg.arrays[inpname].dtype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                              inparr.storage)
    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        inp = state.add_read(inpname)
        out = state.add_write(outname)
        tasklet = state.add_tasklet(func, {'__inp'}, {'__out'},
                                    '__out = {f}(__inp)'.format(f=func))
        state.add_edge(inp, None, tasklet, '__inp',
                       Memlet.from_array(inpname, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(outname, outarr))
    else:
        state.add_mapped_tasklet(
            name=func,
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={
                '__inp':
                Memlet.simple(
                    inpname,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            code='__out = {f}(__inp)'.format(f=func),
            outputs={
                '__out':
                Memlet.simple(
                    outname,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            external_edges=True)

    return outname
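
# Usage sketch (illustrative, not from the source): applying a unary math
# function element-wise. With more than one element the helper expands into a
# mapped tasklet; with exactly one element it emits a plain tasklet instead.
import dace
sdfg = dace.SDFG('simple_call_demo')
state = sdfg.add_state()
sdfg.add_array('A', [16], dace.float64)
out = _simple_call(sdfg, state, 'A', 'exp')  # __out = exp(__inp) per element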
Example #8
def _binop(sdfg: SDFG, state: SDFGState, op1: str, op2: str, opcode: str,
           opname: str, restype: dace.typeclass):
    """ Implements a general element-wise array binary operator. """
    arr1 = sdfg.arrays[op1]
    arr2 = sdfg.arrays[op2]

    out_shape, all_idx_dict, all_idx, arr1_idx, arr2_idx = _broadcast_together(
        arr1.shape, arr2.shape)

    name, _ = sdfg.add_temp_transient(out_shape, restype, arr1.storage)
    state.add_mapped_tasklet("_%s_" % opname,
                             all_idx_dict, {
                                 '__in1': Memlet.simple(op1, arr1_idx),
                                 '__in2': Memlet.simple(op2, arr2_idx)
                             },
                             '__out = __in1 %s __in2' % opcode,
                             {'__out': Memlet.simple(name, all_idx)},
                             external_edges=True)
    return name
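
# Usage sketch (illustrative, not from the source): an element-wise add of a
# matrix and a vector; _broadcast_together supplies the broadcast index
# expressions, and the result type is chosen by the caller. 'A' and 'B' are
# hypothetical.
import dace
sdfg = dace.SDFG('binop_demo')
state = sdfg.add_state()
sdfg.add_array('A', [4, 4], dace.float64)
sdfg.add_array('B', [4], dace.float64)
out = _binop(sdfg, state, 'A', 'B', '+', 'add', dace.float64)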
Example #9
def _unop(sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str):
    """ Implements a general element-wise array unary operator. """
    arr1 = sdfg.arrays[op1]

    name, _ = sdfg.add_temp_transient(arr1.shape, arr1.dtype, arr1.storage)
    state.add_mapped_tasklet(
        "_%s_" % opname,
        {'__i%d' % i: '0:%s' % s
         for i, s in enumerate(arr1.shape)}, {
             '__in1':
             Memlet.simple(
                 op1, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))
         },
        '__out = %s __in1' % opcode, {
            '__out':
            Memlet.simple(
                name, ','.join(['__i%d' % i for i in range(len(arr1.shape))]))
        },
        external_edges=True)
    return name
Example #10
def _array_x_binop(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
                   op1: str, op2: str, op: str, opcode: str):

    arr1 = sdfg.arrays[op1]
    type1 = arr1.dtype.type
    isscal1 = _is_scalar(sdfg, op1)
    isnum1 = isscal1 and (op1 in visitor.numbers.values())
    if isnum1:
        type1 = inverse_dict_lookup(visitor.numbers, op1)
    arr2 = sdfg.arrays[op2]
    type2 = arr2.dtype.type
    isscal2 = _is_scalar(sdfg, op2)
    isnum2 = isscal2 and (op2 in visitor.numbers.values())
    if isnum2:
        type2 = inverse_dict_lookup(visitor.numbers, op2)
    if _is_op_boolean(op):
        restype = dace.bool
    else:
        restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    if isscal1 and isscal2:
        arr1 = sdfg.arrays[op1]
        arr2 = sdfg.arrays[op2]
        op3, arr3 = sdfg.add_temp_transient([1], restype, arr2.storage)
        tasklet = state.add_tasklet('_SS%s_' % op, {'s1', 's2'}, {'s3'},
                                    's3 = s1 %s s2' % opcode)
        n1 = state.add_read(op1)
        n2 = state.add_read(op2)
        n3 = state.add_write(op3)
        state.add_edge(n1, None, tasklet, 's1',
                       dace.Memlet.from_array(op1, arr1))
        state.add_edge(n2, None, tasklet, 's2',
                       dace.Memlet.from_array(op2, arr2))
        state.add_edge(tasklet, 's3', n3, None,
                       dace.Memlet.from_array(op3, arr3))
        return op3
    else:
        return _binop(sdfg, state, op1, op2, opcode, op, restype)
Example #11
def _distr_matmult(pv: 'ProgramVisitor',
                   sdfg: SDFG,
                   state: SDFGState,
                   opa: str,
                   opb: str,
                   shape: Sequence[Union[sp.Expr, Number]],
                   a_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   b_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   c_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None):

    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]

    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape

    a_block_sizes = a_block_sizes or arra.shape
    if len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    if len(arra.shape) == 1 and len(arrb.shape) == 2:
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    a_bsizes_range = None
    if isinstance(a_block_sizes, (list, tuple)):
        if isinstance(a_block_sizes[0], str):
            a_bsizes_name, a_bsizes_range = a_block_sizes
            a_bsizes_desc = sdfg.arrays[a_bsizes_name]
            a_bsizes_node = state.add_read(a_bsizes_name)
        else:
            a_bsizes_name, a_bsizes_desc = sdfg.add_temp_transient(
                (len(a_block_sizes), ), dtype=dace.int32)
            a_bsizes_node = state.add_access(a_bsizes_name)
            a_bsizes_tasklet = state.add_tasklet(
                '_set_a_bsizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(a_block_sizes)
                ]))
            state.add_edge(a_bsizes_tasklet, '__out', a_bsizes_node, None,
                           Memlet.from_array(a_bsizes_name, a_bsizes_desc))
    else:
        a_bsizes_name = a_block_sizes
        a_bsizes_desc = sdfg.arrays[a_bsizes_name]
        a_bsizes_node = state.add_read(a_bsizes_name)

    b_bsizes_range = None
    if isinstance(b_block_sizes, (list, tuple)):
        if isinstance(b_block_sizes[0], str):
            b_bsizes_name, b_bsizes_range = b_block_sizes
            b_bsizes_desc = sdfg.arrays[b_bsizes_name]
            b_bsizes_node = state.add_read(b_bsizes_name)
        else:
            b_bsizes_name, b_bsizes_desc = sdfg.add_temp_transient(
                (len(b_block_sizes), ), dtype=dace.int32)
            b_bsizes_node = state.add_access(b_bsizes_name)
            b_bsizes_tasklet = state.add_tasklet(
                '_set_b_sizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(b_block_sizes)
                ]))
            state.add_edge(b_bsizes_tasklet, '__out', b_bsizes_node, None,
                           Memlet.from_array(b_bsizes_name, b_bsizes_desc))
    else:
        b_bsizes_name = b_block_sizes
        b_bsizes_desc = sdfg.arrays[b_bsizes_name]
        b_bsizes_node = state.add_read(b_bsizes_name)

    if len(arra.shape) == 2 and len(arrb.shape) == 2:
        # Gemm
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif len(arra.shape) == 2 and len(arrb.shape) == 1:
        # Gemv
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    elif len(arra.shape) == 1 and len(arrb.shape) == 2:
        # Gemv transposed
        # Swap a and b
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    if a_bsizes_range:
        a_bsizes_mem = Memlet.simple(a_bsizes_name, a_bsizes_range)
    else:
        a_bsizes_mem = Memlet.from_array(a_bsizes_name, a_bsizes_desc)
    if b_bsizes_range:
        b_bsizes_mem = Memlet.simple(b_bsizes_name, b_bsizes_range)
    else:
        b_bsizes_mem = Memlet.from_array(b_bsizes_name, b_bsizes_desc)

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes',
                   a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes',
                   b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]
Example #12
def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):

    from dace.libraries.blas.nodes.matmul import MatMul  # Avoid import loop

    arr1 = sdfg.arrays[op1]
    arr2 = sdfg.arrays[op2]

    if len(arr1.shape) > 1 and len(arr2.shape) > 1:  # matrix * matrix

        if len(arr1.shape) > 3 or len(arr2.shape) > 3:
            raise SyntaxError(
                'Matrix multiplication of tensors of dimensions > 3 '
                'not supported')

        if arr1.shape[-1] != arr2.shape[-2]:
            raise SyntaxError('Matrix dimension mismatch %s != %s' %
                              (arr1.shape[-1], arr2.shape[-2]))

        from dace.libraries.blas.nodes.matmul import _get_batchmm_opts

        # Determine batched multiplication
        bopt = _get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape,
                                 arr2.strides, None, None)
        if bopt:
            output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
        else:
            output_shape = (arr1.shape[-2], arr2.shape[-1])

    elif len(arr1.shape) == 2 and len(arr2.shape) == 1:  # matrix * vector

        if arr1.shape[1] != arr2.shape[0]:
            raise SyntaxError("Number of matrix columns {} must match"
                              "size of vector {}.".format(
                                  arr1.shape[1], arr2.shape[0]))

        output_shape = (arr1.shape[0], )

    elif len(arr1.shape) == 1 and len(arr2.shape) == 1:  # vector * vector

        if arr1.shape[0] != arr2.shape[0]:
            raise SyntaxError("Vectors in vector product must have same size: "
                              "{} vs. {}".format(arr1.shape[0], arr2.shape[0]))

        output_shape = (1, )

    else:  # Dunno what this is, bail

        raise SyntaxError(
            "Cannot multiply arrays with shapes: {} and {}".format(
                arr1.shape, arr2.shape))

    type1 = arr1.dtype.type
    type2 = arr2.dtype.type
    restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type]

    op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage)

    acc1 = state.add_read(op1)
    acc2 = state.add_read(op2)
    acc3 = state.add_write(op3)

    tasklet = MatMul('_MatMult_', restype)
    state.add_node(tasklet)
    state.add_edge(acc1, None, tasklet, '_a', dace.Memlet.from_array(op1, arr1))
    state.add_edge(acc2, None, tasklet, '_b', dace.Memlet.from_array(op2, arr2))
    state.add_edge(tasklet, '_c', acc3, None, dace.Memlet.from_array(op3, arr3))

    return op3
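
# Usage sketch (illustrative, not from the source): a plain matrix-matrix
# product lowered to a MatMul library node. The visitor argument is unused in
# the body, so None is passed here; 'A' and 'B' are hypothetical.
import dace
sdfg = dace.SDFG('matmul_demo')
state = sdfg.add_state()
sdfg.add_array('A', [4, 3], dace.float64)
sdfg.add_array('B', [3, 5], dace.float64)
out = _matmult(None, sdfg, state, 'A', 'B')  # 'out' names a new 4x5 transient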
Example #13
def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: bool = None):
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum
        try:
            import opt_einsum as oe
        except (ModuleNotFoundError, NameError, ImportError):
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.')

        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(
            einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        result_node = None

        # Follow path and create a chain of operation SDFG states
        for pair, nonfree, expr, after, blas in path_info.contraction_list:
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          arrays[pair[0]],
                                                          arrays[pair[1]],
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=None,
                                                          nodes=input_nodes)
            arrays = ([a for i, a in enumerate(arrays) if i not in pair] +
                      [result])
            input_nodes[result] = result_node

        return arrays[0], result_node
        # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
    output_index = ','.join(o for o in einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, odesc = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        to_init = init_output if init_output is not None else True

    is_conflicted = not all(
        all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet(
                    'einsum_reset',
                    {k: '0:%s' % chardict[k]
                     for k in einsum.output}, {},
                    'out_%s = 0' % output,
                    {'out_%s' % output: Memlet.simple(output, output_index)},
                    external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(),
                                           {'out_%s' % output},
                                           'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None,
                                    Memlet.simple(output, '0'))

        wcr = 'lambda a,b: a+b' if is_conflicted else None
        # Pure einsum map
        state.add_mapped_tasklet(
            'einsum', {k: '0:%s' % v
                       for k, v in chardict.items()}, {
                           'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
                           for inp, arr in zip(einsum.inputs, arrays)
                       },
            'out_%s = %s' % (output, ' * '.join('inp_%s' % arr
                                                for arr in arrays)),
            {
                'out_%s' % output: Memlet.simple(
                    output, output_index, wcr_str=wcr)
            },
            input_nodes=input_nodes,
            output_nodes={output: c},
            external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] +
                             1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] +
                             1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] +
                             1:]) if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] +
                             1:]) if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] +
                             1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']
        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']
        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'},
                                           strides)
        state.add_edge(a, None, nsdfg_node, 'X',
                       Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y',
                       Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None,
                       Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
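
# Usage sketch (illustrative, not from the source): a standard 'ij,jk->ik'
# contraction; the dtype is inferred from the first operand and, since the
# expression matches a (batched) GEMM pattern, it is lowered to a nested GEMM
# SDFG. 'A' and 'B' are hypothetical.
import dace
sdfg = dace.SDFG('einsum_demo')
state = sdfg.add_state()
sdfg.add_array('A', [4, 3], dace.float64)
sdfg.add_array('B', [3, 5], dace.float64)
out, out_node = _create_einsum_internal(sdfg, state, 'ij,jk->ik', 'A', 'B')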
Example #14
def _argminmax(sdfg: SDFG,
               state: SDFGState,
               a: str,
               axis,
               func,
               result_type=dace.int32,
               return_both=False):
    nest = NestedCall(sdfg, state)

    assert func in ['min', 'max']

    if axis is None or type(axis) is not int:
        raise SyntaxError('Axis must be an int')

    a_arr = sdfg.arrays[a]

    if not 0 <= axis < len(a_arr.shape):
        raise SyntaxError("Expected 0 <= axis < len({}.shape), got {}".format(
            a, axis))

    reduced_shape = list(copy.deepcopy(a_arr.shape))
    reduced_shape.pop(axis)

    val_and_idx = dace.struct('_val_and_idx', val=a_arr.dtype, idx=result_type)

    # HACK: since identity cannot be specified for structs, we have to init the output array
    reduced_structs, reduced_struct_arr = sdfg.add_temp_transient(
        reduced_shape, val_and_idx)

    code = "__init = _val_and_idx(val={}, idx=-1)".format(
        dtypes.min_value(a_arr.dtype) if func ==
        'max' else dtypes.max_value(a_arr.dtype))

    nest.add_state().add_mapped_tasklet(
        name="_arg{}_convert_".format(func),
        map_ranges={
            '__i%d' % i: '0:%s' % n
            for i, n in enumerate(a_arr.shape) if i != axis
        },
        inputs={},
        code=code,
        outputs={
            '__init':
            Memlet.simple(
                reduced_structs, ','.join('__i%d' % i
                                          for i in range(len(a_arr.shape))
                                          if i != axis))
        },
        external_edges=True)

    nest.add_state().add_mapped_tasklet(
        name="_arg{}_reduce_".format(func),
        map_ranges={'__i%d' % i: '0:%s' % n
                    for i, n in enumerate(a_arr.shape)},
        inputs={
            '__in':
            Memlet.simple(
                a, ','.join('__i%d' % i for i in range(len(a_arr.shape))))
        },
        code="__out = _val_and_idx(idx={}, val=__in)".format("__i%d" % axis),
        outputs={
            '__out':
            Memlet.simple(
                reduced_structs,
                ','.join('__i%d' % i for i in range(len(a_arr.shape))
                         if i != axis),
                wcr_str=("lambda x, y:"
                         "_val_and_idx(val={}(x.val, y.val), "
                         "idx=(y.idx if x.val {} y.val else x.idx))").format(
                             func, '<' if func == 'max' else '>'))
        },
        external_edges=True)

    if return_both:
        outidx, outidxarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, result_type)
        outval, outvalarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, a_arr.dtype)

        nest.add_state().add_mapped_tasklet(
            name="_arg{}_extract_".format(func),
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(a_arr.shape) if i != axis
            },
            inputs={
                '__in':
                Memlet.simple(
                    reduced_structs, ','.join('__i%d' % i
                                              for i in range(len(a_arr.shape))
                                              if i != axis))
            },
            code="__out_val = __in.val\n__out_idx = __in.idx",
            outputs={
                '__out_val':
                Memlet.simple(
                    outval, ','.join('__i%d' % i
                                     for i in range(len(a_arr.shape))
                                     if i != axis)),
                '__out_idx':
                Memlet.simple(
                    outidx, ','.join('__i%d' % i
                                     for i in range(len(a_arr.shape))
                                     if i != axis))
            },
            external_edges=True)

        return nest, (outval, outidx)

    else:
        # map to result_type
        out, outarr = sdfg.add_temp_transient(
            sdfg.arrays[reduced_structs].shape, result_type)
        nest(_elementwise)("lambda x: x.idx", reduced_structs, out_array=out)
        return nest, out
Example #15
def _elementwise(sdfg: SDFG,
                 state: SDFGState,
                 func: str,
                 in_array: str,
                 out_array=None):
    """Apply a lambda function to each element in the input"""

    inparr = sdfg.arrays[in_array]
    restype = sdfg.arrays[in_array].dtype

    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                                    inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    func_ast = ast.parse(func)
    try:
        lambda_ast = func_ast.body[0].value
        if len(lambda_ast.args.args) != 1:
            raise SyntaxError(
                "Expected lambda with one arg, but {} has {}".format(
                    func, len(lambda_ast.args.args)))
        arg = lambda_ast.args.args[0].arg
        body = astutils.unparse(lambda_ast.body)
    except AttributeError:
        raise SyntaxError("Could not parse func {}".format(func))

    code = "__out = {}".format(body)

    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg,
                       Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(out_array, outarr))
    else:
        state.add_mapped_tasklet(
            name="_elementwise_",
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={
                arg:
                Memlet.simple(
                    in_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            code=code,
            outputs={
                '__out':
                Memlet.simple(
                    out_array,
                    ','.join(['__i%d' % i for i in range(len(inparr.shape))]))
            },
            external_edges=True)

    return out_array
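
# Usage sketch (illustrative, not from the source): the lambda string is parsed
# with ast, its single argument becomes the tasklet input connector, and its
# body becomes the tasklet code. 'A' is a hypothetical array name.
import dace
sdfg = dace.SDFG('elementwise_demo')
state = sdfg.add_state()
sdfg.add_array('A', [10], dace.float64)
out = _elementwise(sdfg, state, 'lambda x: x * x + 1', 'A')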
Example #16
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
              in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):

    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather

    libnode = BlockCyclicGather('_BCGather_')

    inbuf_range = None
    if isinstance(in_buffer, tuple):
        inbuf_name, inbuf_range = in_buffer
    else:
        inbuf_name = in_buffer
    in_desc = sdfg.arrays[inbuf_name]
    inbuf_node = state.add_read(inbuf_name)

    bsizes_range = None
    if isinstance(block_sizes, (list, tuple)):
        if isinstance(block_sizes[0], str):
            bsizes_name, bsizes_range = block_sizes
            bsizes_desc = sdfg.arrays[bsizes_name]
            bsizes_node = state.add_read(bsizes_name)
        else:
            bsizes_name, bsizes_desc = sdfg.add_temp_transient(
                (len(block_sizes), ), dtype=dace.int32)
            bsizes_node = state.add_access(bsizes_name)
            bsizes_tasklet = state.add_tasklet(
                '_set_bsizes_', {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(block_sizes)
                ]))
            state.add_edge(bsizes_tasklet, '__out', bsizes_node, None,
                           Memlet.from_array(bsizes_name, bsizes_desc))
    else:
        bsizes_name = block_sizes
        bsizes_desc = sdfg.arrays[bsizes_name]
        bsizes_node = state.add_read(bsizes_name)

    outbuf_range = None
    if isinstance(out_buffer, tuple):
        outbuf_name, outbuf_range = out_buffer
    else:
        outbuf_name = out_buffer
    out_desc = sdfg.arrays[outbuf_name]
    outbuf_node = state.add_write(outbuf_name)

    if inbuf_range:
        inbuf_mem = Memlet.simple(inbuf_name, inbuf_range)
    else:
        inbuf_mem = Memlet.from_array(inbuf_name, in_desc)
    if bsizes_range:
        bsizes_mem = Memlet.simple(bsizes_name, bsizes_range)
    else:
        bsizes_mem = Memlet.from_array(bsizes_name, bsizes_desc)
    if outbuf_range:
        outbuf_mem = Memlet.simple(outbuf_name, outbuf_range)
    else:
        outbuf_mem = Memlet.from_array(outbuf_name, out_desc)

    state.add_edge(inbuf_node, None, libnode, '_inbuffer', inbuf_mem)
    state.add_edge(bsizes_node, None, libnode, '_block_sizes', bsizes_mem)
    state.add_edge(libnode, '_outbuffer', outbuf_node, None, outbuf_mem)

    return None
Example #17
    def apply(self, sdfg: sd.SDFG):
        graph: sd.SDFGState = sdfg.nodes()[self.state_id]
        map_entry = graph.node(self.subgraph[DeduplicateAccess._map_entry])
        node1 = graph.node(self.subgraph[DeduplicateAccess._node1])
        node2 = graph.node(self.subgraph[DeduplicateAccess._node2])

        # Steps:
        # 1. Find unique subsets
        # 2. Find sets of contiguous subsets
        # 3. Create transients for subsets
        # 4. Redirect edges through new transients

        edges1 = set(e.src_conn for e in graph.edges_between(map_entry, node1))
        edges2 = set(e.src_conn for e in graph.edges_between(map_entry, node2))

        # Only apply to first connector (determinism)
        conn = sorted(edges1 & edges2)[0]

        edges = [e for e in graph.out_edges(map_entry) if e.src_conn == conn]

        # Get original data descriptor
        dname = edges[0].data.data
        desc = sdfg.arrays[edges[0].data.data]

        # Get unique subsets
        unique_subsets = set(e.data.subset for e in edges)

        # Find largest contiguous subsets
        try:
            # Start from stride-1 dimension
            contiguous_subsets = helpers.find_contiguous_subsets(
                unique_subsets,
                dim=next(i for i, s in enumerate(desc.strides) if s == 1))
        except (StopIteration, NotImplementedError):
            warnings.warn(
                "DeduplicateAcces::Not operating on Stride One Dimension!")
            contiguous_subsets = unique_subsets
        # Then find subsets for rest of the dimensions
        contiguous_subsets = helpers.find_contiguous_subsets(
            contiguous_subsets)
        # Map original edges to subsets
        edge_mapping = defaultdict(list)
        for e in edges:
            for ind, subset in enumerate(contiguous_subsets):
                if subset.covers(e.data.subset):
                    edge_mapping[ind].append(e)
                    break
            else:
                raise ValueError(
                    "Failed to find contiguous subset for edge %s" % e.data)

        # Create transients for subsets and redirect edges
        for ind, subset in enumerate(contiguous_subsets):
            name, _ = sdfg.add_temp_transient(subset.size(), desc.dtype)
            anode = graph.add_access(name)
            graph.add_edge(map_entry, conn, anode, None,
                           Memlet(data=dname, subset=subset))
            for e in edge_mapping[ind]:
                graph.remove_edge(e)
                new_memlet = copy.deepcopy(e.data)
                new_edge = graph.add_edge(anode, None, e.dst, e.dst_conn,
                                          new_memlet)
                for pe in graph.memlet_tree(new_edge):
                    # Rename data on memlet
                    pe.data.data = name
                    # Offset memlets to match new transient
                    pe.data.subset.offset(subset, True)