Ejemplo n.º 1
0
def scatter_add(data, indices, updates):
    """
    Build an extern op that adds each slice of `updates` into the output
    slice selected by the matching entry of `indices`, i.e. the generated IR
    performs out[indices[i], j...] += updates[i, j...].

    Args:
        data: [x, y, z]
        indices: [n]
        updates: [n, y, z]
    Output:
        [x, y, z]
    """
    outer_dims = list(updates.shape[:1])
    inner_dims = list(updates.shape[1:])

    def _build(data_buf, index_buf, update_buf, result_buf):
        # The IR only reads and writes the output buffer; the data buffer
        # itself is never referenced here.
        del data_buf
        builder = tvm.ir_builder.create()
        with builder.for_range_n(outer_dims, "i") as row:
            with builder.for_range_n(inner_dims, "j") as rest:
                # `row` and `rest` are index lists, so `+` concatenates them.
                source = row + rest
                target = [builder.load(index_buf, row)] + rest
                summed = builder.load(update_buf, source) + builder.load(
                    result_buf, target)
                builder.store(result_buf, target, summed)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], ins[1], ins[2], outs[0])

    result_buffer = tvm.decl_buffer(data.shape, data.dtype, "out_buf")
    return tvm.extern(
        [data.shape],
        [data, indices, updates],
        _fcompute,
        dtype=data.dtype,
        out_buffers=[result_buffer],
        name="fused_scatter_add",
    )
Ejemplo n.º 2
0
def csr_div(inputs, attrs):
    """Build an extern op dividing stored CSR values by elements of a dense tensor.

    For every stored position ``pos`` of row ``i`` (with column
    ``col = col_idx[pos]``) the generated IR writes
    ``sparse_data[pos, k...] / dense[...]`` to ``output[pos, k...]``; which
    ``dense`` element is read depends on how ``dense`` broadcasts against the
    sparse tensor's dense shape (see the flags below).

    Args:
        inputs: Sequence unpacked as ``(row_idx, col_idx, sparse_data, dense)``.
        attrs: Mapping providing ``"dense_shape"``, the dense shape of the
            sparse operand.

    Returns:
        The result of ``tvm.extern``; the output buffer has the shape and
        dtype of ``sparse_data``.

    Raises:
        AssertionError: If ``dense`` and ``sparse_data`` differ in dtype.
    """
    row_idx, col_idx, sparse_data, dense = inputs
    shape = tuple(attrs["dense_shape"])
    # Trailing dims of the stored values (everything after the nnz axis).
    feature_shape = get_shape(sparse_data.shape)[1:]
    assert dense.dtype == sparse_data.dtype, "data and weight must have the same dtype"

    # row_idx holds one more entry than there are rows (row i spans
    # row_idx[i] .. row_idx[i + 1]).
    num_rows = row_idx.shape[0] - 1
    dense_shape = get_shape(dense.shape)
    sparse_shape = get_shape(shape)
    broadcast_shape = get_broadcast_shape(dense_shape, sparse_shape)
    # The three flags below are Python booleans wrapped as TVM constants so
    # they can be used as conditions of ib.if_scope inside the IR.
    # dense has fewer dims than the broadcast result: index it by column only.
    need_expand = tvm.const(len(dense_shape) < len(broadcast_shape))
    # Same rank, but dense's first dim is smaller: read it at index 0.
    need_broadcast_first_dim = tvm.const(
        len(dense_shape) == len(broadcast_shape)
        and dense_shape[0] < broadcast_shape[0])
    # Same rank, but dense's second dim is smaller: read it at index 0.
    need_broadcast_last_dim = tvm.const(
        len(dense_shape) == len(broadcast_shape)
        and dense_shape[1] < broadcast_shape[1])

    def gen_ir(dense, sparse_data, col_idx, row_idx, output):
        # Parameters are the buffers handed over by tvm.extern; they shadow
        # the tensors of the enclosing function.
        ib = tvm.ir_builder.create()
        # Records the average number of stored values per row as an "INFO"
        # attribute (max(..., 1) avoids dividing by zero rows).
        ib.scope_attr("INFO", "csr_avg_row",
                      int(sparse_data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name='i') as i:
            start = ib.load(row_idx, i)
            end = ib.load(row_idx, i + 1)
            with ib.for_range(0, end - start, name='j') as j:
                pos = start + j
                with ib.for_range_n(feature_shape, 'k') as k:
                    # NOTE(review): this guard looks redundant with the loop
                    # extent `end - start`; presumably kept for later
                    # scheduling passes -- confirm.
                    with ib.if_scope(pos < end):
                        col = ib.load(col_idx, pos)
                        store_loc = [pos] + k
                        val = ib.load(sparse_data, store_loc)
                        with ib.if_scope(need_expand):
                            ib.store(output, store_loc,
                                     val / ib.load(dense, [col] + k))
                        with ib.else_scope():
                            with ib.if_scope(need_broadcast_first_dim):
                                ib.store(output, store_loc,
                                         val / ib.load(dense, [0, col] + k))
                            with ib.else_scope():
                                with ib.if_scope(need_broadcast_last_dim):
                                    ib.store(output, store_loc,
                                             val / ib.load(dense, [i, 0] + k))
                                with ib.else_scope():
                                    # No broadcasting: dense is indexed by
                                    # (row, column).
                                    ib.store(
                                        output, store_loc,
                                        val / ib.load(dense, [i, col] + k))
        return ib.get()

    output_name = "T_csr_div_" + dense.op.name + "_" + sparse_data.op.name
    out_buf = tvm.decl_buffer(sparse_data.shape, sparse_data.dtype,
                              output_name)
    # The incoming attrs are no longer needed; the name is reused for the
    # attributes attached to the extern op.
    attrs = {"remove_self_dependence": True, "csr_op": True}
    return tvm.extern(
        [sparse_data.shape], [dense, sparse_data, col_idx, row_idx],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=sparse_data.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)
Ejemplo n.º 3
0
def gather(data, indices, axis, flag):
    """Only support axis=0.

    Builds an extern op whose IR computes out[i, j...] = data[indices[i], j...].
    The op is named "fused_gather" followed by `flag`.
    """
    rank = len(data.shape)
    if axis < 0:
        axis += rank
    assert 0 <= axis < rank

    src_dims = list(data.shape)
    idx_dims = list(indices.shape)
    output_shape = src_dims[:axis] + idx_dims + src_dims[axis + 1:]
    outer_dims = output_shape[:1]
    inner_dims = output_shape[1:]

    def _build(data_buf, index_buf, result_buf):
        builder = tvm.ir_builder.create()
        with builder.for_range_n(outer_dims, 'i') as row:
            with builder.for_range_n(inner_dims, 'j') as rest:
                # The leading coordinate comes from `indices`; the remaining
                # coordinates are shared by source and destination.
                source = [builder.load(index_buf, row)] + rest
                builder.store(result_buf, row + rest,
                              builder.load(data_buf, source))
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], ins[1], outs[0])

    result_buffer = tvm.decl_buffer(output_shape, data.dtype, "out_buf")

    return tvm.extern(
        [output_shape],
        [data, indices],
        _fcompute,
        dtype=data.dtype,
        out_buffers=[result_buffer],
        name="fused_gather" + flag,
    )
Ejemplo n.º 4
0
def gather(inputs, attrs):
    """Build an extern op that gathers slices of ``data`` along ``axis``.

    The generated IR computes
    ``out[i..., j..., k...] = data[i..., indices[j...], k...]`` where ``i``
    ranges over the dims before ``axis``, ``j`` over the dims of ``indices``
    and ``k`` over the dims after ``axis``. An index outside
    ``[0, data.shape[axis])`` yields zeros instead of a read.

    Args:
        inputs: Exactly two tensors, ``(data, indices)``.
        attrs: Mapping; the optional ``"axis"`` entry is a sequence whose
            first element is the gather axis (default 0). Negative values
            count from the last dimension of ``data``.

    Returns:
        The result of ``tvm.extern``; the output buffer has shape
        ``data.shape[:axis] + indices.shape + data.shape[axis + 1:]``.

    Raises:
        ValueError: If ``inputs`` does not hold two tensors, or ``axis`` is
            outside ``[-rank, rank)`` for the rank of ``data``.
    """
    attrs = {k: v for k, v in attrs.items()}
    axis = int(attrs["axis"][0]) if "axis" in attrs else 0
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)

    rank = len(data_shape)
    if not -rank <= axis < rank:
        raise ValueError(
            f"axis {axis} is out of range for data of rank {rank}")
    # Normalize before slicing: with a negative axis, data_shape[:axis] and
    # data_shape[axis + 1:] would select the wrong dimensions.
    if axis < 0:
        axis += rank

    output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(data_shape[:axis], "i") as i:
            with ib.for_range_n(indices_shape, "j") as j:
                load_idx = ib.load(indices, j)
                inbound = tvm.all(load_idx >= 0, load_idx < data_shape[axis])
                read_idx = i + [load_idx]
                with ib.for_range_n(data_shape[axis + 1:], "k") as k:
                    with ib.if_scope(inbound):
                        ib.store(out, i + j + k, ib.load(data, read_idx + k))
                    with ib.else_scope():
                        # Out-of-range index: write zero rather than read
                        # outside `data`.
                        ib.store(out, i + j + k, tvm.const(0, data.dtype))
        return ib.get()

    output_name = "T_gather_" + data.op.name + "_" + indices.op.name + "_" + str(
        axis)
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    # Declare the same shape as the output buffer (the gathered shape, not
    # the shape of `data`).
    return tvm.extern([output_shape], [data, indices],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name)
Ejemplo n.º 5
0
def coo2csr(inputs, attrs):
    """Build an extern op that turns COO row indices into CSR row pointers.

    The generated IR fills ``output[i]``, for ``i`` in ``0..height``, with the
    number of entries of ``inputs[0]`` that are strictly smaller than ``i``.

    Args:
        inputs: Sequence whose first element is the row-index tensor.
        attrs: Mapping providing ``'height'`` (converted with ``int``).

    Returns:
        The result of ``tvm.extern`` with an output of ``height + 1`` entries.
    """
    coo_rows = inputs[0]
    row_count = int(attrs['height'])
    entry_count = coo_rows.shape[0]
    ptr_len = row_count + 1

    def _build(rows_buf, ptr_buf):
        builder = tvm.ir_builder.create()
        with builder.for_range(0, ptr_len, name='i') as slot:
            # Start every pointer at zero, then count the qualifying entries.
            builder.store(ptr_buf, slot, tvm.const(0, rows_buf.dtype))
            with builder.for_range(0, entry_count, name='j') as entry:
                entry_row = builder.load(rows_buf, entry)
                with builder.if_scope(slot > entry_row):
                    running = builder.load(ptr_buf, slot)
                    builder.store(ptr_buf, slot, running + 1)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], outs[0])

    output_name = "T_coo2csr_" + coo_rows.op.name
    result_buffer = tvm.decl_buffer(ptr_len, coo_rows.dtype, "output_data")

    return tvm.extern([ptr_len], [coo_rows],
                      _fcompute,
                      dtype=coo_rows.dtype,
                      out_buffers=[result_buffer],
                      name=output_name)
Ejemplo n.º 6
0
def csr2coo(inputs, attrs):
    """Build an extern op that expands CSR row pointers into COO row indices.

    For each row ``i`` the generated IR writes ``i`` (cast to the pointer
    dtype) to every output position in ``[indptr[i], indptr[i + 1])``.

    Args:
        inputs: Sequence whose first element is the row-pointer tensor.
        attrs: Mapping providing ``"nnz"``, the number of output entries.

    Returns:
        The result of ``tvm.extern`` with an output of ``nnz`` entries.
    """
    row_ptr = inputs[0]
    row_count = row_ptr.shape[0] - 1
    nnz = int(attrs["nnz"])

    def _build(ptr_buf, rows_buf):
        builder = tvm.ir_builder.create()
        # Average stored entries per row; max(..., 1) guards against a
        # tensor with zero rows.
        builder.scope_attr("INFO", "csr_avg_row",
                           nnz // max(int(row_count), 1))
        with builder.for_range(0, row_count, name='i') as row:
            first = builder.load(ptr_buf, row)
            last = builder.load(ptr_buf, row + 1)
            with builder.for_range(0, last - first, name='j') as offset:
                slot = first + offset
                with builder.if_scope(slot < last):
                    row_value = tvm.expr.Cast(ptr_buf.dtype, row)
                    builder.store(rows_buf, slot, row_value)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], outs[0])

    output_name = "T_csr2coo_" + row_ptr.op.name
    result_buffer = tvm.decl_buffer(nnz, row_ptr.dtype, "output_data")
    extern_attrs = {"csr_op": True}

    return tvm.extern([nnz], [row_ptr],
                      _fcompute,
                      dtype=row_ptr.dtype,
                      out_buffers=[result_buffer],
                      name=output_name,
                      attrs=extern_attrs)
Ejemplo n.º 7
0
def csr_gather(inputs, attrs):
    """Build an extern op that picks the dense values at CSR positions.

    For every stored position ``pos`` of row ``i`` the generated IR copies
    ``dense[i, col_idx[pos], k...]`` to ``output[pos, k...]``.

    Args:
        inputs: Sequence unpacked as ``(row_idx, col_idx, dense)``.
        attrs: Not read by this function.

    Returns:
        The result of ``tvm.extern``; the output shape is the shape of
        ``col_idx`` followed by the dims of ``dense`` after the first two.
    """
    row_idx, col_idx, dense = inputs

    row_count = row_idx.shape[0] - 1
    trailing_dims = get_shape(dense.shape[2:])

    def _build(dense_buf, col_buf, row_buf, result_buf):
        builder = tvm.ir_builder.create()
        avg_per_row = int(col_buf.shape[0]) // max(int(row_count), 1)
        builder.scope_attr("INFO", "csr_avg_row", avg_per_row)
        with builder.for_range(0, row_count, name='i') as row:
            first = builder.load(row_buf, row)
            last = builder.load(row_buf, row + 1)
            with builder.for_range(0, last - first, name='j') as offset:
                slot = first + offset
                with builder.for_range_n(trailing_dims, 'k') as rest:
                    with builder.if_scope(slot < last):
                        column = builder.load(col_buf, slot)
                        picked = builder.load(dense_buf, [row, column] + rest)
                        builder.store(result_buf, [slot] + rest, picked)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], ins[1], ins[2], outs[0])

    output_name = "T_csr_gather_" + dense.op.name
    output_shape = get_shape(col_idx.shape) + trailing_dims
    result_buffer = tvm.decl_buffer(output_shape, dense.dtype, "output_data")
    extern_attrs = {"remove_self_dependence": True, "csr_op": True}
    return tvm.extern(
        [output_shape], [dense, col_idx, row_idx],
        _fcompute,
        dtype=dense.dtype,
        out_buffers=[result_buffer],
        name=output_name,
        attrs=extern_attrs)
Ejemplo n.º 8
0
def tensor_scatter_add(inputs, attrs):
    """Build an extern op that adds ``updates`` into the output at N-d indices.

    Each index tuple read from ``indices`` selects a slice of the output; the
    matching slice of ``updates`` is added to it. Index tuples with any
    coordinate outside ``[0, data.shape[k])`` are skipped.

    Args:
        inputs: Exactly three tensors, ``(data, indices, updates)``. 1-D
            ``indices`` are treated as a list of single-coordinate tuples.
        attrs: Not read by this function.

    Returns:
        The result of ``tvm.extern``; the output has the shape and dtype of
        ``data``.

    Raises:
        ValueError: If ``inputs`` does not hold three tensors.
    """
    if len(inputs) != 3:
        raise ValueError(f"3 inputs expected, but got {len(inputs)}")
    data, indices, updates = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    is_1d_indices = len(indices_shape) == 1
    if is_1d_indices:
        # View shape [n] as [n, 1] so the slicing below is uniform.
        indices_shape.append(1)
    left_shape = indices_shape[:-1]
    right_shape = data_shape[int(indices_shape[-1]):]

    def _build(data_buf, index_buf, update_buf, result_buf):
        # Only the output buffer is read and written by the IR.
        del data_buf
        builder = tvm.ir_builder.create()
        with builder.for_range_n(left_shape, "i") as outer:
            with builder.for_range_n(right_shape, "j") as inner:
                source = outer + inner
                coords = []
                valid = True
                # For 1-D indices the depth is the 1 appended above.
                for axis in range(0, int(indices_shape[-1])):
                    if is_1d_indices:
                        coord = builder.load(index_buf, outer)
                    else:
                        coord = builder.load(index_buf, outer + [axis])
                    if axis == 0:
                        valid = tvm.all((coord >= 0),
                                        (coord < data_shape[axis]))
                    else:
                        valid = tvm.all(valid, (coord >= 0),
                                        (coord < data_shape[axis]))
                    coords.append(coord)
                target = coords + inner
                with builder.if_scope(valid):
                    summed = builder.load(update_buf, source) + builder.load(
                        result_buf, target)
                    builder.store(result_buf, target, summed)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], ins[1], ins[2], outs[0])

    output_name = "T_tsa_" + data.op.name + "_" + indices.op.name + "_" + updates.op.name
    result_buffer = tvm.decl_buffer(data.shape, data.dtype, output_name)
    extern_attrs = {"disable_inline_inject": True}
    return tvm.extern(
        [data.shape], [data, indices, updates],
        _fcompute,
        dtype=data.dtype,
        out_buffers=[result_buffer],
        name=output_name,
        attrs=extern_attrs)
Ejemplo n.º 9
0
def tensor_unsorted_segment_sum(inputs, attrs):
    """Build an extern op that sums slices of ``data`` per segment id.

    The generated IR performs ``out[indices[i], j...] += data[i, j...]`` for
    every ``i``; entries whose segment id is outside ``[0, num_segments)``
    are skipped.

    Args:
        inputs: Exactly two tensors, ``(data, indices)``. The leading dims of
            ``data`` must equal the dims of ``indices``; only 1-D ``indices``
            are supported.
        attrs: Mapping providing ``'num_segments'``, the size of the first
            output dimension.

    Returns:
        The result of ``tvm.extern``; the output buffer has shape
        ``[num_segments] + data.shape[len(indices.shape):]``.

    Raises:
        ValueError: On a wrong number of inputs, on ``indices`` having higher
            rank than ``data``, on mismatching leading dims, or on
            ``indices`` that are not 1-D.
    """
    attrs = {k: v for k, v in attrs.items()}
    num = attrs['num_segments']
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    if len(data_shape) < len(indices_shape):
        raise ValueError('input rank should not be less than segment_id rank')
    for i, (seg_dim, data_dim) in enumerate(zip(indices_shape, data_shape)):
        if int(seg_dim) != int(data_dim):
            raise ValueError(
                f'input shape at dim {i} is not equal to segment_id shape at dim {i}'
            )
    # The slice is empty when data and indices have the same rank.
    output_shape = [num] + data_shape[len(indices_shape):]
    if len(indices_shape) > 1:
        raise ValueError('only 1-D segment currently supported')

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(indices_shape, "i") as i:
            read_idx = ib.load(indices, i)
            # The bounds check depends only on the segment id, so build it
            # once per entry rather than inside the inner loop.
            inbound = tvm.all((read_idx >= 0), (read_idx < num))
            # 1-D segment
            with ib.for_range_n(data_shape[1:], 'j') as j:
                with ib.if_scope(inbound):
                    val = ib.load(data, i + j) + ib.load(out, [read_idx] + j)
                    ib.store(out, [read_idx] + j, val)
        return ib.get()

    output_name = "T_uss_" + data.op.name + "_" + indices.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    attrs = {"disable_inline_inject": True}
    # Declare the same shape as the output buffer (the segmented shape, not
    # the shape of `data`).
    return tvm.extern([output_shape], [data, indices],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name,
                      attrs=attrs)
Ejemplo n.º 10
0
def csrmv(inputs, _):
    """Build an extern op multiplying a CSR matrix by the first column of ``weight``.

    The generated IR computes, for every row,
    ``out[row, 0] = sum(data[e] * weight[indices[e], 0])`` over the stored
    positions ``e`` in ``[indptr[row], indptr[row + 1])``.

    Args:
        inputs: Sequence unpacked as ``(indptr, indices, data, weight)``;
            ``data`` must be 1-D and ``weight`` 2-D, with equal dtypes.
        _: Unused.

    Returns:
        The result of ``tvm.extern``; the output buffer has shape
        ``[num_rows, 1]`` and the dtype of ``data``.

    Raises:
        AssertionError: If the rank or dtype requirements are not met.
    """
    indptr, indices, data, weight = inputs
    assert len(data.shape) == 1 and len(
        weight.shape) == 2, "only supports 2-dim sparse tensor"
    assert data.dtype == weight.dtype, "data and weight must have same dtype."

    # indptr holds one more entry than there are rows.
    num_rows = indptr.shape[0] - 1

    def csrmv_ir(data, indices, indptr, weight, out):
        # Parameters are buffers, in the order the tensors are passed to
        # tvm.extern below (data, indices, indptr, weight).
        ib = tvm.ir_builder.create()
        # Average stored values per row; max(..., 1) avoids dividing by zero.
        ib.scope_attr("INFO", "csr_avg_row",
                      int(data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name="row") as row:
            # Zero the accumulator before summing the row.
            ib.store(out, [row, 0], tvm.const(0, data.dtype))
            row_start = ib.load(indptr, row)
            row_end = ib.load(indptr, row + 1)
            row_elems = row_end - row_start
            with ib.for_range(0, row_elems, name="idx") as idx:
                elem = row_start + idx
                # Positions at or past row_end contribute zero.
                val = tvm.expr.Select(
                    elem < row_end,
                    ib.load(data, elem) *
                    ib.load(weight, [ib.load(indices, elem), 0]),
                    tvm.const(0, data.dtype))
                # NOTE(review): presumably tags the "idx" loop as a reduction
                # axis (extent weight.shape[0]) for later passes -- confirm.
                ib.scope_attr(
                    [tvm.api.iter_var_api(
                        (0, weight.shape[0]), "idx", 2)], "reduce_update", "")
                temp = val + ib.load(out, [row, 0])
                ib.store(out, [row, 0], temp)
        return ib.get()

    output_shape = [num_rows, 1]
    output_name = "T_csrmv_" + weight.op.name + "_" + data.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    attrs = {"csr_op": True}
    return tvm.extern(
        [output_shape], [data, indices, indptr, weight],
        lambda ins, outs: csrmv_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)
Ejemplo n.º 11
0
def csr_reduce_sum(inputs, attrs):
    """Build an extern op summing the stored values of each CSR row.

    The generated IR computes ``output[i, 0, k...]`` as the sum of
    ``data[pos, k...]`` over the stored positions ``pos`` in
    ``[row_idx[i], row_idx[i + 1])``.

    Args:
        inputs: Sequence unpacked as ``(row_idx, <ignored>, data)``.
        attrs: Mapping providing ``'axis'`` (a sequence whose first element
            is the reduction axis; negative values are wrapped using the
            rank of ``'dense_shape'``) and ``'dense_shape'``.

    Returns:
        The result of ``tvm.extern``; the output shape is ``dense_shape``
        with dimension 1 replaced by 1.

    Raises:
        AssertionError: If the normalized axis is not 1.
    """
    row_idx, _, data = inputs
    # Currently, just support integer axis
    axis = int(attrs['axis'][0])
    shape = tuple(attrs['dense_shape'])
    # row_idx holds one more entry than there are rows.
    num_rows = row_idx.shape[0] - 1
    if axis < 0:
        axis += len(shape)
    assert axis == 1, "only supports reduction of CSR axis 1"
    # Trailing dims of the stored values (everything after the nnz axis).
    feature_shape = get_shape(data.shape)[1:]
    # Keep the reduced axis with extent 1.
    fused_shape = (shape[0], 1) + shape[2:]

    def gen_ir(data, row_idx, output):
        # Parameters are the buffers handed over by tvm.extern; they shadow
        # the tensors of the enclosing function.
        ib = tvm.ir_builder.create()
        # Average stored values per row; max(..., 1) avoids dividing by zero.
        ib.scope_attr("INFO", "csr_avg_row",
                      int(data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name="i") as i:
            start = ib.load(row_idx, i)
            end = ib.load(row_idx, i + 1)
            with ib.for_range_n(feature_shape, "k") as k:
                # Zero the accumulator before summing the row.
                ib.store(output, [i, 0] + k, tvm.const(0, data.dtype))
                with ib.for_range(0, end - start, name="j") as j:
                    # NOTE(review): presumably tags the "j" loop as a
                    # reduction axis (extent shape[1]) for later passes --
                    # confirm.
                    ib.scope_attr(
                        [tvm.api.iter_var_api(
                            (0, shape[1]), "j", 2)], "reduce_update", "")
                    pos = start + j
                    # Positions at or past `end` contribute zero.
                    val = tvm.expr.Select(pos < end, ib.load(data, [pos] + k),
                                          tvm.const(0, data.dtype))
                    ib.store(output, [i, 0] + k,
                             val + ib.load(output, [i, 0] + k))
        return ib.get()

    output_shape = fused_shape
    output_name = "T_csr_reduce_sum_" + data.op.name + "_" + str(axis)
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    # The incoming attrs are no longer needed; the name is reused for the
    # attributes attached to the extern op.
    attrs = {"csr_op": True, "fuse_axis_extern": True}
    return tvm.extern([output_shape], [data, row_idx],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name,
                      attrs=attrs)
Ejemplo n.º 12
0
def gather_nd(inputs, attrs):
    """Build an extern op that gathers slices of ``data`` at N-d indices.

    Each index tuple along the last axis of ``indices`` selects a slice of
    ``data`` that is copied to the output; a tuple with any coordinate outside
    ``[0, data.shape[k])`` produces zeros instead.

    Args:
        inputs: Exactly two tensors, ``(data, indices)``.
        attrs: Ignored.

    Returns:
        The result of ``tvm.extern``; the output shape is
        ``indices.shape[:-1] + data.shape[indices.shape[-1]:]``.

    Raises:
        ValueError: If ``inputs`` does not hold two tensors.
    """
    del attrs
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs

    src_dims = list(data.shape)
    idx_dims = list(indices.shape)
    outer_dims = idx_dims[:-1]
    inner_dims = src_dims[int(idx_dims[-1]):]

    def _build(data_buf, index_buf, result_buf):
        builder = tvm.ir_builder.create()
        with builder.for_range_n(outer_dims, 'i') as outer:
            with builder.for_range_n(inner_dims, 'j') as inner:
                coords = []
                valid = True
                for axis in range(0, int(idx_dims[-1])):
                    coord = builder.load(index_buf, outer + [axis])
                    if axis == 0:
                        valid = tvm.all((coord >= 0),
                                        (coord < src_dims[axis]))
                    else:
                        valid = tvm.all(valid, (coord >= 0),
                                        (coord < src_dims[axis]))
                    coords.append(coord)
                destination = outer + inner
                with builder.if_scope(valid):
                    builder.store(result_buf, destination,
                                  builder.load(data_buf, coords + inner))
                with builder.else_scope():
                    # Out-of-range tuple: write zero instead of reading.
                    builder.store(result_buf, destination,
                                  tvm.const(0, data_buf.dtype))
        return builder.get()

    def _fcompute(ins, outs):
        return _build(ins[0], ins[1], outs[0])

    output_name = "T_gathernd_" + data.op.name + "_" + indices.op.name
    output_shape = outer_dims + inner_dims
    result_buffer = tvm.decl_buffer(output_shape, data.dtype, output_name)
    return tvm.extern([output_shape], [data, indices],
                      _fcompute,
                      dtype=data.dtype,
                      out_buffers=[result_buffer],
                      name=output_name)
Ejemplo n.º 13
0
def csr_mm(inputs, _):
    """Build an extern op multiplying a CSR matrix by a dense 2-D matrix.

    The generated IR computes
    ``out[row, col] = sum(data[e] * dense[indices[e], col])`` over the stored
    positions ``e`` in ``[indptr[row], indptr[row + 1])``.

    Args:
        inputs: Sequence unpacked as ``(indptr, indices, data, dense)``; the
            first three must be 1-D, ``dense`` 2-D, and ``data`` and
            ``dense`` must share a dtype.
        _: Unused.

    Returns:
        The result of ``tvm.extern``; the output buffer has shape
        ``[num_rows, dense.shape[1]]`` and the dtype of ``data``.

    Raises:
        AssertionError: If a rank or dtype requirement is not met.
    """
    indptr, indices, data, dense = inputs
    assert len(indptr.shape) == 1, "CSRTensor.indptr should be 1-dim."
    assert len(indices.shape) == 1, "CSRTensor.indices should be 1-dim."
    assert len(data.shape) == 1, "CSRTensor.values should be 1-dim."
    assert len(dense.shape) == 2, "Dense Tensor should be 2-dim."
    assert data.dtype == dense.dtype, "values and dense should have the same dtype."
    # indptr holds one more entry than there are rows.
    num_rows = indptr.shape[0] - 1
    num_cols = dense.shape[1]

    def csr_mm_ir(indptr, indices, data, dense, out):
        # Parameters are the buffers handed over by tvm.extern; they shadow
        # the tensors of the enclosing function.
        ib = tvm.ir_builder.create()
        with ib.for_range(0, num_rows, name="row") as row:
            row_start = ib.load(indptr, row)
            row_end = ib.load(indptr, row + 1)
            num_eles = row_end - row_start
            with ib.for_range(0, num_cols, name="col") as col:
                # Zero the accumulator before summing the row's products.
                ib.store(out, [row, col], tvm.const(0, data.dtype))
                with ib.for_range(0, num_eles, name="strides") as strides:
                    idx = row_start + strides
                    # Positions at or past row_end contribute zero.
                    val = tvm.expr.Select(
                        idx < row_end,
                        ib.load(data, idx) *
                        ib.load(dense, [ib.load(indices, idx), col]),
                        tvm.const(0, data.dtype))
                    # NOTE(review): presumably tags the "strides" loop as a
                    # reduction axis (extent dense.shape[0]) for later
                    # passes -- confirm.
                    ib.scope_attr(
                        [tvm.api._IterVar((0, dense.shape[0]), "strides", 2)],
                        "reduce_update", "")
                    temp = val + ib.load(out, [row, col])
                    ib.store(out, [row, col], temp)
        return ib.get()

    output_shape = [num_rows, num_cols]
    output_name = "T_csr_mm_" + dense.op.name + "_" + data.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    return tvm.extern(
        [output_shape], [indptr, indices, data, dense],
        lambda ins, outs: csr_mm_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name=output_name)
Ejemplo n.º 14
0
def standard_normal(inputs, attrs):
    """Build an extern op that fills a float32 tensor via "StandardNormal" calls.

    Every output element is stored from one ``ib.extern_call`` to the
    external op named "StandardNormal", which receives the seed.

    Args:
        inputs: Ignored.
        attrs: Mapping providing ``"seed"`` and ``"shape"``.

    Returns:
        The result of ``tvm.extern``; the output buffer has the requested
        shape and dtype "float32".
    """
    del inputs
    options = dict(attrs.items())
    seed = options["seed"]
    shape = options["shape"]
    dtype = "float32"

    def _build(result_buf):
        builder = tvm.ir_builder.create()
        with builder.for_range_n(shape, "i") as position:
            sample = builder.extern_call(seed,
                                         op_name="StandardNormal",
                                         dtype=dtype)
            builder.store(result_buf, position, sample)
        return builder.get()

    def _fcompute(ins, outs):
        return _build(outs[0])

    result_buffer = tvm.decl_buffer(shape, dtype, "res")
    return tvm.extern([shape], [],
                      _fcompute,
                      dtype=dtype,
                      out_buffers=[result_buffer],
                      name="randnorm")