# These kernels emit TVM IR directly through the ir_builder. `tvm`, `get_shape` and
# `get_broadcast_shape` are assumed to be provided by the surrounding module's
# imports (an akg-style TVM wrapper).


def scatter_add(data, indices, updates):
    """
    Scatter-add `updates` into the output along axis 0.

    Args:
        data: [x, y, z]
        indices: [n]
        updates: [n, y, z]

    Output: [x, y, z]
    """
    left_shape = list(updates.shape[:1])
    right_shape = list(updates.shape[1:])

    def gen_ir(data, indices, updates, out):
        del data
        ib = tvm.ir_builder.create()
        with ib.for_range_n(left_shape, "i") as i:
            with ib.for_range_n(right_shape, "j") as j:
                idx_updates = i + j
                idx_data = [ib.load(indices, i)] + j
                temp = ib.load(updates, idx_updates) + ib.load(out, idx_data)
                ib.store(out, idx_data, temp)
        return ib.get()

    out_buf = tvm.decl_buffer(data.shape, data.dtype, "out_buf")
    return tvm.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="fused_scatter_add",
    )

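# Hypothetical NumPy sketch (not used by the kernel above) of the semantics the IR
# emits, assuming the extern output buffer already holds `data` before the
# accumulation loops run; the IR itself only adds into `out`.
def _scatter_add_np_reference(data, indices, updates):
    import numpy as np
    out = np.array(data, copy=True)
    for n in range(indices.shape[0]):
        out[indices[n]] += updates[n]  # accumulate row n of updates into row indices[n]
    return out
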
def csr_div(inputs, attrs):
    """Divide the stored values of a CSR tensor by a broadcastable dense tensor."""
    row_idx, col_idx, sparse_data, dense = inputs
    shape = tuple(attrs["dense_shape"])
    feature_shape = get_shape(sparse_data.shape)[1:]
    assert dense.dtype == sparse_data.dtype, "dense and sparse data must have the same dtype"

    num_rows = row_idx.shape[0] - 1
    dense_shape = get_shape(dense.shape)
    sparse_shape = get_shape(shape)
    broadcast_shape = get_broadcast_shape(dense_shape, sparse_shape)
    need_expand = tvm.const(len(dense_shape) < len(broadcast_shape))
    need_broadcast_first_dim = tvm.const(
        len(dense_shape) == len(broadcast_shape) and dense_shape[0] < broadcast_shape[0])
    need_broadcast_last_dim = tvm.const(
        len(dense_shape) == len(broadcast_shape) and dense_shape[1] < broadcast_shape[1])

    def gen_ir(dense, sparse_data, col_idx, row_idx, output):
        ib = tvm.ir_builder.create()
        ib.scope_attr("INFO", "csr_avg_row", int(sparse_data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name='i') as i:
            start = ib.load(row_idx, i)
            end = ib.load(row_idx, i + 1)
            with ib.for_range(0, end - start, name='j') as j:
                pos = start + j
                with ib.for_range_n(feature_shape, 'k') as k:
                    with ib.if_scope(pos < end):
                        col = ib.load(col_idx, pos)
                        store_loc = [pos] + k
                        val = ib.load(sparse_data, store_loc)
                        with ib.if_scope(need_expand):
                            ib.store(output, store_loc, val / ib.load(dense, [col] + k))
                        with ib.else_scope():
                            with ib.if_scope(need_broadcast_first_dim):
                                ib.store(output, store_loc, val / ib.load(dense, [0, col] + k))
                            with ib.else_scope():
                                with ib.if_scope(need_broadcast_last_dim):
                                    ib.store(output, store_loc, val / ib.load(dense, [i, 0] + k))
                                with ib.else_scope():
                                    ib.store(output, store_loc,
                                             val / ib.load(dense, [i, col] + k))
        return ib.get()

    output_name = "T_csr_div_" + dense.op.name + "_" + sparse_data.op.name
    out_buf = tvm.decl_buffer(sparse_data.shape, sparse_data.dtype, output_name)
    attrs = {"remove_self_dependence": True, "csr_op": True}
    return tvm.extern(
        [sparse_data.shape],
        [dense, sparse_data, col_idx, row_idx],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=sparse_data.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)

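# Hypothetical 2-D NumPy reference for csr_div (trailing feature dims are omitted):
# broadcasting `dense` to the sparse tensor's dense_shape and indexing it at
# (i, col) covers the expand / first-dim / last-dim special cases the IR branches on.
def _csr_div_np_reference(row_idx, col_idx, sparse_data, dense, dense_shape):
    import numpy as np
    dense_b = np.broadcast_to(dense, dense_shape)
    out = np.empty_like(sparse_data)
    for i in range(len(row_idx) - 1):
        for pos in range(row_idx[i], row_idx[i + 1]):
            out[pos] = sparse_data[pos] / dense_b[i, col_idx[pos]]
    return out
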
def gather(data, indices, axis, flag):
    """Gather `data` along axis 0 using `indices`. Only axis=0 is supported."""
    ndim = len(data.shape)
    axis = axis + ndim if axis < 0 else axis
    assert axis >= 0
    assert axis < ndim

    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
    left_shape = output_shape[:1]
    right_shape = output_shape[1:]

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(left_shape, 'i') as i:
            with ib.for_range_n(right_shape, 'j') as j:
                read_idx = [ib.load(indices, i)]
                val = ib.load(data, read_idx + j)
                ib.store(out, i + j, val)
        return ib.get()

    out_buf = tvm.decl_buffer(output_shape, data.dtype, "out_buf")
    return tvm.extern(
        [output_shape],
        [data, indices],
        lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="fused_gather" + flag,
    )

def gather(inputs, attrs):
    """Gather `data` along `axis` using `indices`; out-of-range indices yield zeros."""
    attrs = {k: v for k, v in attrs.items()}
    axis = int(attrs["axis"][0]) if "axis" in attrs else 0
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(data_shape[:axis], "i") as i:
            with ib.for_range_n(indices_shape, "j") as j:
                load_idx = ib.load(indices, j)
                inbound = tvm.all(load_idx >= 0, load_idx < data_shape[axis])
                read_idx = i + [load_idx]
                with ib.for_range_n(data_shape[axis + 1:], "k") as k:
                    with ib.if_scope(inbound):
                        ib.store(out, i + j + k, ib.load(data, read_idx + k))
                    with ib.else_scope():
                        ib.store(out, i + j + k, tvm.const(0, data.dtype))
        return ib.get()

    output_name = "T_gather_" + data.op.name + "_" + indices.op.name + "_" + str(axis)
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    return tvm.extern([output_shape], [data, indices],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name)

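# Hypothetical NumPy reference for the gather above: out-of-range indices produce
# zeros rather than raising, matching the inbound guard in the IR.
def _gather_np_reference(data, indices, axis=0):
    import numpy as np
    valid = (indices >= 0) & (indices < data.shape[axis])
    out = np.take(data, np.where(valid, indices, 0), axis=axis)
    # broadcast the validity mask to the output rank and zero out-of-range gathers
    mask_shape = (1,) * axis + indices.shape + (1,) * (data.ndim - axis - 1)
    return np.where(valid.reshape(mask_shape), out, np.zeros((), data.dtype))
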
def coo2csr(inputs, attrs):
    """Convert COO row indices into a CSR row pointer (indptr) of length height + 1."""
    row_indices = inputs[0]
    height = int(attrs['height'])
    nnz = row_indices.shape[0]

    def gen_ir(row_indices, output):
        ib = tvm.ir_builder.create()
        with ib.for_range(0, height + 1, name='i') as i:
            ib.store(output, i, tvm.const(0, row_indices.dtype))
            with ib.for_range(0, nnz, name='j') as j:
                # indptr[i] counts the nonzeros whose row index is strictly less than i
                row = ib.load(row_indices, j)
                with ib.if_scope(i > row):
                    ptr = ib.load(output, i)
                    ib.store(output, i, ptr + 1)
        return ib.get()

    output_name = "T_coo2csr_" + row_indices.op.name
    out_buf = tvm.decl_buffer(height + 1, row_indices.dtype, "output_data")
    return tvm.extern([height + 1], [row_indices],
                      lambda ins, outs: gen_ir(ins[0], outs[0]),
                      dtype=row_indices.dtype,
                      out_buffers=[out_buf],
                      name=output_name)

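# Hypothetical NumPy check for coo2csr: output[i] counts how many COO row indices are
# strictly less than i, which is the CSR indptr of a row-sorted COO tensor.
def _coo2csr_np_reference(row_indices, height):
    import numpy as np
    return np.array([int(np.sum(row_indices < i)) for i in range(height + 1)],
                    dtype=row_indices.dtype)
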
def csr2coo(inputs, attrs):
    """Expand a CSR row pointer (indptr) into per-nonzero COO row indices."""
    indptr = inputs[0]
    num_rows = indptr.shape[0] - 1
    nnz = int(attrs["nnz"])

    def gen_ir(indptr, output):
        ib = tvm.ir_builder.create()
        ib.scope_attr("INFO", "csr_avg_row", nnz // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name='i') as i:
            start = ib.load(indptr, i)
            end = ib.load(indptr, i + 1)
            with ib.for_range(0, end - start, name='j') as j:
                pos = start + j
                with ib.if_scope(pos < end):
                    ib.store(output, pos, tvm.expr.Cast(indptr.dtype, i))
        return ib.get()

    output_name = "T_csr2coo_" + indptr.op.name
    out_buf = tvm.decl_buffer(nnz, indptr.dtype, "output_data")
    attrs = {"csr_op": True}
    return tvm.extern([nnz], [indptr],
                      lambda ins, outs: gen_ir(ins[0], outs[0]),
                      dtype=indptr.dtype,
                      out_buffers=[out_buf],
                      name=output_name,
                      attrs=attrs)

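# Hypothetical NumPy reference for csr2coo: write row index i into every output slot
# covered by the half-open range [indptr[i], indptr[i + 1]).
def _csr2coo_np_reference(indptr):
    import numpy as np
    nnz = int(indptr[-1])
    out = np.empty(nnz, dtype=indptr.dtype)
    for i in range(len(indptr) - 1):
        out[indptr[i]:indptr[i + 1]] = i
    return out
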
def csr_gather(inputs, attrs):
    """Gather dense values at the (row, col) positions described by a CSR pattern."""
    row_idx, col_idx, dense = inputs
    num_rows = row_idx.shape[0] - 1
    feature_shape = get_shape(dense.shape[2:])

    def gen_ir(dense, col_idx, row_idx, output):
        ib = tvm.ir_builder.create()
        ib.scope_attr("INFO", "csr_avg_row", int(col_idx.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name='i') as i:
            start = ib.load(row_idx, i)
            end = ib.load(row_idx, i + 1)
            with ib.for_range(0, end - start, name='j') as j:
                pos = start + j
                with ib.for_range_n(feature_shape, 'k') as k:
                    with ib.if_scope(pos < end):
                        col = ib.load(col_idx, pos)
                        ib.store(output, [pos] + k, ib.load(dense, [i, col] + k))
        return ib.get()

    output_name = "T_csr_gather_" + dense.op.name
    output_shape = get_shape(col_idx.shape) + feature_shape
    out_buf = tvm.decl_buffer(output_shape, dense.dtype, "output_data")
    attrs = {"remove_self_dependence": True, "csr_op": True}
    return tvm.extern(
        [output_shape],
        [dense, col_idx, row_idx],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
        dtype=dense.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)

def tensor_scatter_add(inputs, attrs):
    """Scatter-add `updates` into the output at the N-D positions given by `indices`."""
    if len(inputs) != 3:
        raise ValueError(f"3 inputs expected, but got {len(inputs)}")
    data, indices, updates = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    is_1d_indices = False
    if len(indices_shape) == 1:
        indices_shape.append(1)
        is_1d_indices = True
    left_shape = indices_shape[:-1]
    right_shape = data_shape[int(indices_shape[-1]):]

    def gen_ir(data, indices, updates, out):
        del data
        ib = tvm.ir_builder.create()
        with ib.for_range_n(left_shape, "i") as i:
            with ib.for_range_n(right_shape, "j") as j:
                index_read = i + j
                index_write = []
                inbound = True
                if is_1d_indices:
                    temp_idx = ib.load(indices, i)
                    inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[0]))
                    index_write.append(temp_idx)
                else:
                    for k in range(0, int(indices_shape[-1])):
                        temp_idx = ib.load(indices, i + [k])
                        if k == 0:
                            inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[k]))
                        else:
                            inbound = tvm.all(inbound, (temp_idx >= 0),
                                              (temp_idx < data_shape[k]))
                        index_write.append(temp_idx)
                index_write = index_write + j
                with ib.if_scope(inbound):
                    temp = ib.load(updates, index_read) + ib.load(out, index_write)
                    ib.store(out, index_write, temp)
        return ib.get()

    output_name = "T_tsa_" + data.op.name + "_" + indices.op.name + "_" + updates.op.name
    out_buf = tvm.decl_buffer(data.shape, data.dtype, output_name)
    attrs = {"disable_inline_inject": True}
    return tvm.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)

def tensor_unsorted_segment_sum(inputs, attrs):
    """Sum rows of `data` into `num_segments` output segments selected by `indices`."""
    attrs = {k: v for k, v in attrs.items()}
    num = attrs['num_segments']
    op_id = attrs['op_id'] if 'op_id' in attrs else 0
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    segment_len = len(data_shape) - len(indices_shape)
    if segment_len < 0:
        raise ValueError('input rank should not be less than segment_id rank')
    for i, v in enumerate(indices_shape):
        if int(v) != int(data_shape[i]):
            raise ValueError(
                f'input shape at dim {i} is not equal to segment_id shape at dim {i}')
    output_shape = [num]
    if segment_len > 0:
        output_shape += data_shape[len(indices_shape):]
    if len(indices_shape) > 1:
        raise ValueError('only 1-D segment ids are currently supported')

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(indices_shape, "i") as i:
            read_idx = ib.load(indices, i)
            # 1-D segment ids
            with ib.for_range_n(data_shape[1:], 'j') as j:
                inbound = tvm.all((read_idx >= 0), (read_idx < num))
                with ib.if_scope(inbound):
                    val = ib.load(data, i + j) + ib.load(out, [read_idx] + j)
                    ib.store(out, [read_idx] + j, val)
        return ib.get()

    output_name = "T_uss_" + data.op.name + "_" + indices.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    attrs = {"disable_inline_inject": True}
    return tvm.extern([output_shape], [data, indices],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name,
                      attrs=attrs)

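# Hypothetical NumPy reference for the 1-D segment case handled above: rows whose
# segment id falls outside [0, num) are dropped, the rest are summed per segment.
def _unsorted_segment_sum_np_reference(data, segment_ids, num):
    import numpy as np
    out = np.zeros((num,) + data.shape[1:], dtype=data.dtype)
    for i, seg in enumerate(segment_ids):
        if 0 <= seg < num:
            out[seg] += data[i]
    return out
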
def csrmv(inputs, _):
    """CSR matrix-vector multiply: multiplies a 2-D CSR tensor by the [k, 1] dense `weight`."""
    indptr, indices, data, weight = inputs
    assert len(data.shape) == 1 and len(weight.shape) == 2, "only supports 2-dim sparse tensor"
    assert data.dtype == weight.dtype, "data and weight must have the same dtype"
    num_rows = indptr.shape[0] - 1

    def csrmv_ir(data, indices, indptr, weight, out):
        ib = tvm.ir_builder.create()
        ib.scope_attr("INFO", "csr_avg_row", int(data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name="row") as row:
            ib.store(out, [row, 0], tvm.const(0, data.dtype))
            row_start = ib.load(indptr, row)
            row_end = ib.load(indptr, row + 1)
            row_elems = row_end - row_start
            with ib.for_range(0, row_elems, name="idx") as idx:
                elem = row_start + idx
                val = tvm.expr.Select(
                    elem < row_end,
                    ib.load(data, elem) * ib.load(weight, [ib.load(indices, elem), 0]),
                    tvm.const(0, data.dtype))
                ib.scope_attr(
                    [tvm.api.iter_var_api((0, weight.shape[0]), "idx", 2)],
                    "reduce_update", "")
                temp = val + ib.load(out, [row, 0])
                ib.store(out, [row, 0], temp)
        return ib.get()

    output_shape = [num_rows, 1]
    output_name = "T_csrmv_" + weight.op.name + "_" + data.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    attrs = {"csr_op": True}
    return tvm.extern(
        [output_shape],
        [data, indices, indptr, weight],
        lambda ins, outs: csrmv_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name=output_name,
        attrs=attrs)

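# Hypothetical NumPy reference for csrmv, assuming `weight` is a [k, 1] column vector
# as the IR above expects: out has shape [num_rows, 1].
def _csrmv_np_reference(indptr, indices, data, weight):
    import numpy as np
    num_rows = len(indptr) - 1
    out = np.zeros((num_rows, 1), dtype=data.dtype)
    for row in range(num_rows):
        for pos in range(indptr[row], indptr[row + 1]):
            out[row, 0] += data[pos] * weight[indices[pos], 0]
    return out
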
def csr_reduce_sum(inputs, attrs):
    """Reduce a CSR tensor by summation along axis 1, keeping the reduced axis as size 1."""
    row_idx, _, data = inputs
    # Currently, only an integer axis is supported.
    axis = int(attrs['axis'][0])
    shape = tuple(attrs['dense_shape'])
    num_rows = row_idx.shape[0] - 1
    if axis < 0:
        axis += len(shape)
    assert axis == 1, "only supports reduction of CSR axis 1"

    feature_shape = get_shape(data.shape)[1:]
    fused_shape = (shape[0], 1) + shape[2:]

    def gen_ir(data, row_idx, output):
        ib = tvm.ir_builder.create()
        ib.scope_attr("INFO", "csr_avg_row", int(data.shape[0]) // max(int(num_rows), 1))
        with ib.for_range(0, num_rows, name="i") as i:
            start = ib.load(row_idx, i)
            end = ib.load(row_idx, i + 1)
            with ib.for_range_n(feature_shape, "k") as k:
                ib.store(output, [i, 0] + k, tvm.const(0, data.dtype))
                with ib.for_range(0, end - start, name="j") as j:
                    ib.scope_attr(
                        [tvm.api.iter_var_api((0, shape[1]), "j", 2)],
                        "reduce_update", "")
                    pos = start + j
                    val = tvm.expr.Select(pos < end, ib.load(data, [pos] + k),
                                          tvm.const(0, data.dtype))
                    ib.store(output, [i, 0] + k, val + ib.load(output, [i, 0] + k))
        return ib.get()

    output_shape = fused_shape
    output_name = "T_csr_reduce_sum_" + data.op.name + "_" + str(axis)
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    attrs = {"csr_op": True, "fuse_axis_extern": True}
    return tvm.extern([output_shape], [data, row_idx],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name,
                      attrs=attrs)

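# Hypothetical NumPy reference for csr_reduce_sum over axis 1 (the only supported
# axis): per-row sums of the stored values, trailing feature dims preserved.
def _csr_reduce_sum_np_reference(row_idx, data):
    import numpy as np
    num_rows = len(row_idx) - 1
    out = np.zeros((num_rows, 1) + data.shape[1:], dtype=data.dtype)
    for i in range(num_rows):
        out[i, 0] = data[row_idx[i]:row_idx[i + 1]].sum(axis=0)
    return out
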
def gather_nd(inputs, attrs):
    """Gather slices of `data` at the N-D coordinates in `indices`; out-of-range coordinates yield zeros."""
    del attrs
    if len(inputs) != 2:
        raise ValueError(f"2 inputs expected, but got {len(inputs)}")
    data, indices = inputs
    data_shape = list(data.shape)
    indices_shape = list(indices.shape)
    indices_last_dim = len(indices_shape) - 1
    left_shape = indices_shape[:indices_last_dim]
    right_shape = data_shape[int(indices_shape[indices_last_dim]):]

    def gen_ir(data, indices, out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(left_shape, 'i') as i:
            with ib.for_range_n(right_shape, 'j') as j:
                read_idx = []
                inbound = True
                for k in range(0, int(indices_shape[-1])):
                    temp_idx = ib.load(indices, i + [k])
                    if k == 0:
                        inbound = tvm.all((temp_idx >= 0), (temp_idx < data_shape[k]))
                    else:
                        inbound = tvm.all(inbound, (temp_idx >= 0), (temp_idx < data_shape[k]))
                    read_idx.append(temp_idx)
                with ib.if_scope(inbound):
                    ib.store(out, i + j, ib.load(data, read_idx + j))
                with ib.else_scope():
                    ib.store(out, i + j, tvm.const(0, data.dtype))
        return ib.get()

    output_name = "T_gathernd_" + data.op.name + "_" + indices.op.name
    output_shape = left_shape + right_shape
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    return tvm.extern([output_shape], [data, indices],
                      lambda ins, outs: gen_ir(ins[0], ins[1], outs[0]),
                      dtype=data.dtype,
                      out_buffers=[out_buf],
                      name=output_name)

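# Hypothetical NumPy reference for gather_nd: the last dimension of `indices` indexes
# the leading dimensions of `data`; any out-of-range coordinate yields zeros.
def _gather_nd_np_reference(data, indices):
    import numpy as np
    left_shape = indices.shape[:-1]
    right_shape = data.shape[indices.shape[-1]:]
    out = np.zeros(left_shape + right_shape, dtype=data.dtype)
    for i in np.ndindex(*left_shape):
        coord = tuple(indices[i])
        if all(0 <= c < s for c, s in zip(coord, data.shape)):
            out[i] = data[coord]
    return out
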
def csr_mm(inputs, _):
    """CSR x dense matrix multiply, producing a dense [num_rows, num_cols] result."""
    indptr, indices, data, dense = inputs
    assert len(indptr.shape) == 1, "CSRTensor.indptr should be 1-dim."
    assert len(indices.shape) == 1, "CSRTensor.indices should be 1-dim."
    assert len(data.shape) == 1, "CSRTensor.values should be 1-dim."
    assert len(dense.shape) == 2, "Dense Tensor should be 2-dim."
    assert data.dtype == dense.dtype, "values and dense should have the same dtype."
    num_rows = indptr.shape[0] - 1
    num_cols = dense.shape[1]

    def csr_mm_ir(indptr, indices, data, dense, out):
        ib = tvm.ir_builder.create()
        with ib.for_range(0, num_rows, name="row") as row:
            row_start = ib.load(indptr, row)
            row_end = ib.load(indptr, row + 1)
            num_eles = row_end - row_start
            with ib.for_range(0, num_cols, name="col") as col:
                ib.store(out, [row, col], tvm.const(0, data.dtype))
                with ib.for_range(0, num_eles, name="strides") as strides:
                    idx = row_start + strides
                    val = tvm.expr.Select(
                        idx < row_end,
                        ib.load(data, idx) * ib.load(dense, [ib.load(indices, idx), col]),
                        tvm.const(0, data.dtype))
                    ib.scope_attr(
                        [tvm.api._IterVar((0, dense.shape[0]), "strides", 2)],
                        "reduce_update", "")
                    temp = val + ib.load(out, [row, col])
                    ib.store(out, [row, col], temp)
        return ib.get()

    output_shape = [num_rows, num_cols]
    output_name = "T_csr_mm_" + dense.op.name + "_" + data.op.name
    out_buf = tvm.decl_buffer(output_shape, data.dtype, output_name)
    return tvm.extern(
        [output_shape],
        [indptr, indices, data, dense],
        lambda ins, outs: csr_mm_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name=output_name)

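# Hypothetical NumPy reference for csr_mm: sparse (CSR) x dense matrix product.
def _csr_mm_np_reference(indptr, indices, data, dense):
    import numpy as np
    num_rows = len(indptr) - 1
    out = np.zeros((num_rows, dense.shape[1]), dtype=data.dtype)
    for row in range(num_rows):
        for pos in range(indptr[row], indptr[row + 1]):
            out[row] += data[pos] * dense[indices[pos]]
    return out
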
def standard_normal(inputs, attrs):
    """Fill a float32 buffer of the given shape with samples from the StandardNormal extern call."""
    del inputs
    attrs = {k: v for k, v in attrs.items()}
    seed = attrs["seed"]
    shape = attrs["shape"]
    dtype = "float32"

    def gen_ir(out):
        ib = tvm.ir_builder.create()
        with ib.for_range_n(shape, "i") as i:
            temp = ib.extern_call(seed, op_name="StandardNormal", dtype=dtype)
            ib.store(out, i, temp)
        return ib.get()

    output_name = "randnorm"
    out_buf = tvm.decl_buffer(shape, dtype, "res")
    return tvm.extern([shape], [],
                      lambda ins, outs: gen_ir(outs[0]),
                      dtype=dtype,
                      out_buffers=[out_buf],
                      name=output_name)

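# Loose NumPy analogue of the StandardNormal extern call above; how `seed` is consumed
# by the device-side generator is an assumption here, not something this kernel defines.
def _standard_normal_np_reference(seed, shape):
    import numpy as np
    return np.random.default_rng(seed).standard_normal(shape).astype("float32")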