def apply_func(search_policy, state, stage_id):
    """Generate the initial schedule sketch for a CHW-layout sparse dense op.

    Builds a tiled loop structure for the ``sparse_dense_sp_rhs_bsrmm`` compute
    and its block stage, optionally inlining the output stage into a single
    elementwise consumer.

    Parameters
    ----------
    search_policy : the auto_scheduler search policy (provides ``search_task``).
    state : the raw state object to sketch from.
    stage_id : int, index of the stage being processed.

    Returns
    -------
    list of ``[state_object, next_stage_id]`` pairs for the sketch generator.
    """
    ret = []
    s0 = auto_scheduler.loop_state.State(state,
                                         search_policy.search_task.compute_dag)
    # If we are called on the block stage itself, skip down to its producer.
    if s0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block":
        return [s0.state_object, stage_id - 1]

    sparse_dense = s0.stages[stage_id].op
    sparse_dense_block = s0.stages[stage_id - 1].op
    assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm"
    assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block"

    # Set the default consumer of compute block.
    consumer = sparse_dense

    # If sparse dense has a single elementwise consumer,
    # we can compute-inline the sparse_dense output stage.
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s0.state_object, stage_id)
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
                search_policy.search_task, s0.state_object, stage_id, consumer_id):
            consumer = s0.stages[consumer_id].op
            s0.compute_inline(sparse_dense)

    # Block iterators: (block-col, intra-block-col, row, row_offset, block-row).
    nb_j, j, i, row_offset, c = s0[sparse_dense_block].iters
    n, m = s0[consumer].iters
    # Split the consumer's n axis by the block width so its inner extent
    # matches the intra-block iterator j.
    nb_n, n = s0.split(consumer, n, [j.range.extent])
    # NOTE: each follow_split below must immediately follow the split it
    # mirrors — it references the last transform step by index.
    j0, j1, j2 = s0.split(sparse_dense_block, nb_j, [None, None])
    n0, n1, n2 = s0.follow_split(consumer, nb_n, len(s0.transform_steps) - 1, 2)
    i0, i1, i2, i3, i4 = s0.split(sparse_dense_block, i, [None, None, None, None])
    m0, m1, m2, m3 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 3)
    c0, c1, c2 = s0.split(sparse_dense_block, c, [None, None])
    s0.reorder(sparse_dense_block,
               [j0, i0, j1, i1, j2, i2, c0, row_offset, c1, i3, c2, j, i4])
    s0.reorder(consumer, [n0, m0, n1, m1, n2, m2, n, m3])
    s0.compute_at(sparse_dense_block, consumer, m2)

    ret.append([s0.state_object, stage_id - 2])
    return ret
def sparse_dense_apply_func(search_policy, state, stage_id):
    """Describe how to generate the initial sketch for sparse dense.

    Tiles the ``sparse_dense_sp_rhs_bsrmm`` block stage and its output stage
    (or a single elementwise consumer, if one exists) with matching splits,
    then computes the block stage at the consumer's outer output-column axis.

    Parameters
    ----------
    search_policy : the auto_scheduler search policy (provides ``search_task``).
    state : the raw state object to sketch from.
    stage_id : int, index of the stage being processed.

    Returns
    -------
    list of ``[state_object, next_stage_id]`` pairs for the sketch generator.
    """
    ret = []
    s_0 = auto_scheduler.loop_state.State(
        state, search_policy.search_task.compute_dag)
    # If we are called on the block stage itself, skip down to its producer.
    if s_0.stages[stage_id].op.tag == "sparse_dense_sp_rhs_bsrmm_block":
        return [s_0.state_object, stage_id - 1]

    sparse_dense = s_0.stages[stage_id].op
    sparse_dense_block = s_0.stages[stage_id - 1].op
    assert sparse_dense.tag == "sparse_dense_sp_rhs_bsrmm"
    assert sparse_dense_block.tag == "sparse_dense_sp_rhs_bsrmm_block"

    # Set the default consumer of compute block
    consumer = sparse_dense

    # If sparse dense has a single elementwise consumer
    # We can compute inline the sparse_dense output stage
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s_0.state_object, stage_id)
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
                search_policy.search_task, s_0.state_object, stage_id, consumer_id):
            consumer = s_0.stages[consumer_id].op
            s_0.compute_inline(sparse_dense)

    # Block iterators: (row, block-col, intra-block-col, row_offset, block-row).
    i, nb_j, j, row_offset, c = s_0[sparse_dense_block].iters
    m, n = s_0[consumer].iters
    # NOTE: each follow_split below must immediately follow the split it
    # mirrors — it references the last transform step by index.
    i_0, i_1, i_2 = s_0.split(sparse_dense_block, i, [None, None])
    m_0, m_1 = s_0.follow_split(consumer, m, len(s_0.transform_steps) - 1, 1)
    j_0, j_1 = s_0.split(sparse_dense_block, nb_j, [None])
    n_0, n_1 = s_0.follow_split(consumer, n, len(s_0.transform_steps) - 1, 1)
    s_0.reorder(sparse_dense_block,
                [i_0, j_0, i_1, j_1, row_offset, i_2, j, c])
    s_0.reorder(consumer, [m_0, n_0, m_1, n_1])
    s_0.compute_at(sparse_dense_block, consumer, n_0)

    ret.append([s_0.state_object, stage_id - 2])
    return ret
def apply_func(search_policy, state, stage_id):
    """Generate the initial schedule sketch for a ``sparse_dense_v2`` op.

    Chooses one of two tiling structures depending on the global
    ``ARGS.output_layout`` setting ('hwc' or 'chw'), optionally inlining the
    output stage into a single elementwise consumer first.

    Parameters
    ----------
    search_policy : the auto_scheduler search policy (provides ``search_task``).
    state : the raw state object to sketch from.
    stage_id : int, index of the stage being processed.

    Returns
    -------
    list of ``[state_object, next_stage_id]`` pairs for the sketch generator.
    """
    ret = []
    s0 = auto_scheduler.loop_state.State(state,
                                         search_policy.search_task.compute_dag)
    # Tags carry a 4-character suffix that is stripped before comparison —
    # presumably a layout tag like "_hwc"; TODO(review) confirm it is always
    # exactly 4 characters.
    if s0.stages[stage_id].op.tag[:-4] == "sparse_dense_v2_block":
        return [s0.state_object, stage_id - 1]

    sparse_dense = s0.stages[stage_id].op
    sparse_dense_block = s0.stages[stage_id - 1].op
    assert sparse_dense.tag[:-4] == "sparse_dense_v2"
    assert sparse_dense_block.tag[:-4] == "sparse_dense_v2_block"

    # Set the default consumer of compute block.
    consumer = sparse_dense

    # If sparse dense has a single elementwise consumer,
    # we can compute-inline the sparse_dense output stage.
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s0.state_object, stage_id)
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
                search_policy.search_task, s0.state_object, stage_id, consumer_id):
            consumer = s0.stages[consumer_id].op
            s0.compute_inline(sparse_dense)

    # NOTE: in both branches each follow_split must immediately follow the
    # split it mirrors — it references the last transform step by index.
    if ARGS.output_layout == 'hwc':
        i, nb_j, j, row_offset, c = s0[sparse_dense_block].iters
        m, n = s0[consumer].iters
        # Split the consumer's n axis by the block width so its inner extent
        # matches the intra-block iterator j.
        nb_n, n = s0.split(consumer, n, [j.range.extent])
        i0, i1, i2, i3 = s0.split(sparse_dense_block, i, [None, None, None])
        m0, m1, m2 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 2)
        j0, j1 = s0.split(sparse_dense_block, nb_j, [None])
        n0, n1 = s0.follow_split(consumer, nb_n, len(s0.transform_steps) - 1, 1)
        c0, c1, c2 = s0.split(sparse_dense_block, c, [None, None])
        s0.reorder(sparse_dense_block,
                   [i0, j0, i1, j1, c0, row_offset, c1, i2, c2, i3, j])
        s0.reorder(consumer, [m0, n0, m1, n1, m2, n])
        s0.compute_at(sparse_dense_block, consumer, n1)
    elif ARGS.output_layout == 'chw':
        # Same structure with the spatial/channel iterator roles swapped.
        nb_j, j, i, row_offset, c = s0[sparse_dense_block].iters
        n, m = s0[consumer].iters
        nb_n, n = s0.split(consumer, n, [j.range.extent])
        i0, i1, i2, i3 = s0.split(sparse_dense_block, i, [None, None, None])
        m0, m1, m2 = s0.follow_split(consumer, m, len(s0.transform_steps) - 1, 2)
        j0, j1 = s0.split(sparse_dense_block, nb_j, [None])
        n0, n1 = s0.follow_split(consumer, nb_n, len(s0.transform_steps) - 1, 1)
        c0, c1, c2 = s0.split(sparse_dense_block, c, [None, None])
        s0.reorder(sparse_dense_block,
                   [j0, i0, j1, i1, c0, row_offset, c1, i2, c2, j, i3])
        s0.reorder(consumer, [n0, m0, n1, m1, n, m2])
        s0.compute_at(sparse_dense_block, consumer, m1)

    ret.append([s0.state_object, stage_id - 2])
    return ret
def sparse_conv2d_apply_func(search_policy, state, stage_id):
    """Describe how to generate the initial sketch for sparse conv2d.

    Handles both NHWC and NCHW layouts (read from the op's ``layout`` attr)
    and both the 6-iterator (block size 1) and 7-iterator block stages.
    Tiles the block stage and its consumer with matching splits, then
    computes the block stage at the consumer's outer channel axis.

    Parameters
    ----------
    search_policy : the auto_scheduler search policy (provides ``search_task``).
    state : the raw state object to sketch from.
    stage_id : int, index of the stage being processed.

    Returns
    -------
    list of ``[state_object, next_stage_id]`` pairs for the sketch generator.
    """
    ret = []
    s_0 = auto_scheduler.loop_state.State(
        state, search_policy.search_task.compute_dag)
    # If we are called on the block stage itself, skip down to its producer.
    if s_0.stages[stage_id].op.tag == "sparse_conv2d_sp_bsrmm_block":
        return [s_0.state_object, stage_id - 1]

    sparse_conv2d = s_0.stages[stage_id].op
    sparse_conv2d_block = s_0.stages[stage_id - 1].op
    assert sparse_conv2d.tag == "sparse_conv2d_sp_bsrmm"
    assert sparse_conv2d_block.tag == "sparse_conv2d_sp_bsrmm_block"
    layout = sparse_conv2d.attrs["layout"]

    # Set the default consumer of compute block
    consumer = sparse_conv2d

    # If sparse conv2d has a single elementwise consumer
    # We can compute inline the sparse_conv2d output stage
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s_0.state_object, stage_id)
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
                search_policy.search_task, s_0.state_object, stage_id, consumer_id):
            consumer = s_0.stages[consumer_id].op
            s_0.compute_inline(sparse_conv2d)

    # c stays None when the block stage has only 6 iterators (block size 1),
    # which selects the shorter reorder lists below.
    c = None
    if layout == "NHWC":
        if len(s_0[sparse_conv2d_block].iters) == 6:  # bs_c = 1
            i, h, w, nb_j, j, row_offset = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block].iters
        else:
            i, h, w, nb_j, j, row_offset, c = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block].iters
        m, x, y, n = s_0[consumer].iters
    elif layout == "NCHW":
        if len(s_0[sparse_conv2d_block].iters) == 6:  # bs_c = 1
            i, nb_j, j, h, w, row_offset = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block].iters
        else:
            i, nb_j, j, h, w, row_offset, c = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block].iters
        m, n, x, y = s_0[consumer].iters

    # NOTE: each follow_split below must immediately follow the split it
    # mirrors — it references the last transform step by index.
    i_0, i_1, i_2 = s_0.split(sparse_conv2d_block, i, [None, None])
    m_0, m_1 = s_0.follow_split(consumer, m, len(s_0.transform_steps) - 1, 1)
    h_0, h_1, h_2 = s_0.split(sparse_conv2d_block, h, [None, None])
    x_0, x_1 = s_0.follow_split(consumer, x, len(s_0.transform_steps) - 1, 1)
    w_0, w_1, w_2 = s_0.split(sparse_conv2d_block, w, [None, None])  # pylint: disable=invalid-name
    y_0, y_1 = s_0.follow_split(consumer, y, len(s_0.transform_steps) - 1, 1)
    j_0, j_1 = s_0.split(sparse_conv2d_block, nb_j, [None])
    n_0, n_1 = s_0.follow_split(consumer, n, len(s_0.transform_steps) - 1, 1)

    # Reorder per layout; the c variants append/interleave the block-row
    # iterator at the innermost positions.
    if layout == "NHWC":
        if c is None:
            s_0.reorder(
                sparse_conv2d_block,
                [
                    i_0, h_0, w_0, j_0, i_1, h_1, w_1, j_1, row_offset,
                    i_2, h_2, w_2, j
                ],
            )
        else:
            s_0.reorder(
                sparse_conv2d_block,
                [
                    i_0, h_0, w_0, j_0, i_1, h_1, w_1, j_1, row_offset,
                    i_2, h_2, w_2, j, c
                ],
            )
        s_0.reorder(consumer, [m_0, x_0, y_0, n_0, m_1, x_1, y_1, n_1])
    elif layout == "NCHW":
        if c is None:
            s_0.reorder(
                sparse_conv2d_block,
                [
                    i_0, j_0, h_0, w_0, i_1, j_1, h_1, w_1, row_offset,
                    i_2, j, h_2, w_2
                ],
            )
        else:
            s_0.reorder(
                sparse_conv2d_block,
                [
                    i_0, j_0, h_0, w_0, i_1, j_1, h_1, w_1, row_offset,
                    i_2, j, c, h_2, w_2
                ],
            )
        s_0.reorder(consumer, [m_0, n_0, x_0, y_0, m_1, n_1, x_1, y_1])

    s_0.compute_at(sparse_conv2d_block, consumer, n_0)

    ret.append([s_0.state_object, stage_id - 2])
    return ret