def getdata_kron(expr, arg_data):
    """Build the code-generation data for a Kronecker product ``kron(A, B)``.

    The result's sparsity pattern is the Kronecker product of the two
    argument patterns. Both the integer and the float work sizes are set to
    the column count of the first argument's sparsity pattern.

    Returns a one-element list containing the AtomData.
    """
    lhs, rhs = arg_data[0], arg_data[1]
    pattern = sp.kron(lhs.sparsity, rhs.sparsity)
    lhs_cols = lhs.sparsity.shape[1]
    atom = AtomData(
        expr,
        arg_data,
        macro_name="kron",
        sparsity=pattern,
        work_int=lhs_cols,
        work_float=lhs_cols,
    )
    return [atom]
def getdata_mul(expr, arg_data):
    """Build the code-generation data for a product ``A * B``.

    A ``(1, 1)`` left operand selects ``scalar_mul`` (flagged in place, with
    ``copy_arg=1`` and the right operand's pattern). Otherwise a ``(1, 1)``
    right operand selects ``scalar_rmul`` (flagged in place, ``copy_arg=0``,
    the left operand's pattern). Otherwise a general ``mul`` is emitted. Its
    pattern is ``A.sparsity * B.sparsity``, and its work sizes equal the
    column count of the left operand's pattern.

    Returns a one-element list containing the AtomData.
    """
    lhs, rhs = arg_data[0], arg_data[1]
    if lhs.size == (1, 1):
        atom = AtomData(expr, arg_data, inplace=True, copy_arg=1,
                        macro_name="scalar_mul", sparsity=rhs.sparsity)
    elif rhs.size == (1, 1):
        atom = AtomData(expr, arg_data, inplace=True, copy_arg=0,
                        macro_name="scalar_rmul", sparsity=lhs.sparsity)
    else:
        product_pattern = lhs.sparsity * rhs.sparsity
        lhs_cols = lhs.sparsity.shape[1]
        atom = AtomData(expr, arg_data, macro_name="mul",
                        sparsity=product_pattern,
                        work_int=lhs_cols, work_float=lhs_cols)
    return [atom]
def getdata_hstack(expr, arg_data):
    """Build the code-generation data for a horizontal stack of the arguments.

    The stack is emitted as a left-to-right chain of binary ``hstack``
    operations. The running result is stacked with each later argument in
    turn, and one AtomData is produced per step. The last element of the
    returned list describes the full stack.

    Note: with a single argument no step is needed, and an empty list is
    returned.
    """
    data_list = []
    sparsity = arg_data[0].sparsity
    data = arg_data[0]
    for arg in arg_data[1:]:
        sparsity = sp.hstack([sparsity, arg.sparsity])
        # Intermediate results are narrower than the final stack, so the size
        # of each partial result is passed explicitly from its pattern.
        data = AtomData(expr, [data, arg], macro_name="hstack",
                        sparsity=sparsity, size=sparsity.shape)
        data_list.append(data)
    return data_list
def getdata_diag_vec(expr, arg_data):
    """Build the code-generation data for ``diag_vec``: put a vector on a diagonal.

    Each stored entry of the argument's sparsity pattern at row ``i`` is
    placed at position ``(i, i)`` of an ``m x m`` pattern, where ``m`` is
    the argument's row count. The result is stored in CSR format.

    Only row indices are used. NOTE(review): this assumes the argument is a
    column vector; confirm callers never pass a row vector.

    Returns a one-element list containing the AtomData.
    """
    m = arg_data[0].size[0]
    pattern = sp.coo_matrix(arg_data[0].sparsity)
    rows = pattern.row
    diagonal = sp.coo_matrix((pattern.data, (rows, rows)), shape=(m, m))
    sparsity = sp.csr_matrix(diagonal)
    return [AtomData(expr, arg_data, macro_name="diag_vec", sparsity=sparsity)]
def getdata_reshape(expr, arg_data):
    """Build the code-generation data for reshaping the argument to ``(m_new, n_new)``.

    The target shape comes from ``expr.get_data()``. It is forwarded as the
    AtomData ``data`` payload, and ``m_new`` is used for both work sizes.
    The new sparsity pattern is produced by the module's ``reshape`` helper.

    Returns a one-element list containing the AtomData.
    """
    m_new, n_new = expr.get_data()
    sparsity = reshape(arg_data[0].sparsity, m_new, n_new)
    return [AtomData(expr, arg_data, macro_name='reshape', sparsity=sparsity,
                     work_int=m_new, work_float=m_new, data=(m_new, n_new))]
def _slice_bounds(slc, dim):
    """Return ``(start, stop, step)`` of ``slc`` with ``None`` fields replaced by 0, ``dim`` and 1."""
    start = 0 if slc.start is None else slc.start
    stop = dim if slc.stop is None else slc.stop
    step = 1 if slc.step is None else slc.step
    return start, stop, step


def getdata_index(expr, arg_data):
    """Build the code-generation data for slicing ``A[r0:r1:rs, c0:c1:cs]``.

    ``expr.get_data()[0]`` holds the row slice and the column slice. Missing
    start/stop/step fields default to 0, the dimension length of the
    argument's sparsity pattern, and 1. The resolved bounds are passed on
    as the AtomData ``data`` dict. The result pattern is the same slice
    applied to the argument's pattern.

    Raises:
        ValueError: if a start is negative or a stop exceeds the argument's
            size along that dimension.
    """
    arg = arg_data[0]
    slices = expr.get_data()[0]
    start0, stop0, step0 = _slice_bounds(slices[0], arg.sparsity.shape[0])
    start1, stop1, step1 = _slice_bounds(slices[1], arg.sparsity.shape[1])

    # Report the offending values in the exception instead of printing them.
    if start0 < 0 or stop0 > arg.size[0]:
        raise ValueError(
            "First index out of bounds (arg size: {}, start: {}, stop: {})".format(
                arg.size[0], start0, stop0))
    if start1 < 0 or stop1 > arg.size[1]:
        raise ValueError(
            "Second index out of bounds (arg size: {}, start: {}, stop: {})".format(
                arg.size[1], start1, stop1))

    data = {
        'start0': start0, 'stop0': stop0, 'step0': step0,
        'start1': start1, 'stop1': stop1, 'step1': step1,
    }
    sparsity = arg.sparsity[start0:stop0:step0, start1:stop1:step1]
    return [AtomData(expr, arg_data, macro_name='index', sparsity=sparsity, data=data)]
def getdata_add(expr, arg_data):
    """Build the code-generation data for an n-ary sum ``a0 + a1 + ... + ak``.

    The sum is emitted as a left-to-right chain of binary additions, with one
    AtomData per step. The last element of the returned list describes the
    full sum. Each step is chosen as follows:

    - If the running result is ``(1, 1)``, emit ``scalar_add``: in place,
      ``copy_arg=1``, with the next operand's pattern.
    - Otherwise, if the next operand is ``(1, 1)``, emit ``scalar_radd``: in
      place, ``copy_arg=0``, with the running result's pattern.
    - Otherwise, emit a general ``add``. Its pattern is the sum of the two
      patterns, and its work sizes equal the operand's column count.

    Note: with a single argument no step is needed, and an empty list is
    returned.
    """
    data = arg_data[0]
    sparsity = arg_data[0].sparsity
    data_list = []
    for arg in arg_data[1:]:
        if data.size == (1, 1):
            sparsity = arg.sparsity
            macro_name = 'scalar_add'
            work_int = work_float = 0
            inplace = True
            copy_arg = 1
        elif arg.size == (1, 1):
            sparsity = data.sparsity
            macro_name = 'scalar_radd'
            work_int = work_float = 0
            inplace = True
            copy_arg = 0
        else:
            # Rebind rather than `+=`: `sparsity` may alias an argument's own
            # pattern object, which must not be modified in place.
            sparsity = sparsity + arg.sparsity
            macro_name = 'add'
            work_int = work_float = arg.sparsity.shape[1]
            inplace = False
            copy_arg = 0
        # Give each partial sum its own size (as getdata_hstack does), so the
        # scalar test on `data.size` in the next step sees the partial result.
        data = AtomData(expr, arg_data=[data, arg], macro_name=macro_name,
                        sparsity=sparsity, work_int=work_int,
                        work_float=work_float, inplace=inplace,
                        copy_arg=copy_arg, size=sparsity.shape)
        data_list.append(data)
    return data_list
def getdata_mul_elemwise(expr, arg_data):
    """Build the code-generation data for an elementwise product of two arguments.

    The result pattern is ``A.sparsity.multiply(B.sparsity)``. Both work
    sizes equal the column count of the first argument's pattern.

    Returns a one-element list containing the AtomData.
    """
    first, second = arg_data[0], arg_data[1]
    pattern = first.sparsity.multiply(second.sparsity)
    width = first.sparsity.shape[1]
    return [
        AtomData(expr, arg_data, macro_name="mul_elemwise", sparsity=pattern,
                 work_int=width, work_float=width)
    ]
def getdata_max_entries(expr, arg_data):
    """Build the code-generation data for ``max_entries``.

    No sparsity pattern is passed, and the operation is flagged as not in
    place. Returns a one-element list containing the AtomData.
    """
    atom = AtomData(expr, arg_data, macro_name="max_entries", inplace=False)
    return [atom]
def getdata_diag_mat(expr, arg_data):
    """Build the code-generation data for ``diag_mat``: a matrix diagonal as a column.

    Returns a one-element list containing the AtomData.
    """
    diagonal = arg_data[0].sparsity.diagonal()
    # csr_matrix of a 1-D array is a single row; transpose it into a column.
    column = sp.csr_matrix(diagonal).transpose()
    return [AtomData(expr, arg_data, macro_name='diag_mat', sparsity=column)]
def getdata_trace(expr, arg_data):
    """Build the code-generation data for ``trace``; no sparsity pattern is passed.

    Returns a one-element list containing the AtomData.
    """
    atom = AtomData(expr, arg_data, macro_name="trace")
    return [atom]
def getdata_neg(expr, arg_data):
    """Build the code-generation data for negation ``-A``.

    The result keeps the argument's sparsity pattern, and the operation is
    flagged as in place. Returns a one-element list containing the AtomData.
    """
    pattern = arg_data[0].sparsity
    atom = AtomData(expr, arg_data, macro_name="neg", sparsity=pattern, inplace=True)
    return [atom]