def __init__(self, vm: VM, trans_a: bool, trans_b: bool, op1: Symbol, op2: Symbol, dest: Symbol):
    """Set up a GEMM instruction: validates operand kinds and registers users.

    Raises:
        InternalError: if `dest` is not a register-array symbol, or if either
            operand does not wrap a Matrix object.
    """
    super(Gemm, self).__init__(vm)

    # Transposition flags requested for each operand
    self._trans_a = trans_a
    self._trans_b = trans_b

    self._op1 = op1
    self._op2 = op2
    self._is_ready = True
    self.registers = None

    # The accumulator must live in registers
    if dest.stype != SymbolType.Register:
        raise InternalError(
            f'gemm: accumulator-register array is not provided. Instead: {dest.stype}'
        )
    self._dest = dest

    # Both operands must wrap matrices
    if not isinstance(self._op1.obj, Matrix):
        raise InternalError('gemm: op1 is not a matrix')
    if not isinstance(self._op2.obj, Matrix):
        raise InternalError('gemm: op2 is not a matrix')

    # Record this instruction as a user of all three symbols
    for symbol in (op1, op2, dest):
        symbol.add_user(self)
def __init__(self, vm: VM, src: Symbol, dest: Symbol):
    """Build a get-element-pointer instruction from a batched source matrix.

    Validates that `src` lives in a batch and `dest` in global memory, and
    derives `dest`'s data view from the source matrix dimensions.

    Raises:
        InternalError: on any operand kind/type mismatch.
    """
    super(GetElementPtr, self).__init__(vm)

    # Guard clauses: src must be a batched matrix, dest a global-memory matrix
    if src.stype != SymbolType.Batch:
        raise InternalError('ptr: operand `src` is not in a batch')
    if not isinstance(src.obj, Matrix):
        raise InternalError('ptr: operand `src` is not a matrix')
    if dest.stype != SymbolType.Global:
        raise InternalError('ptr: operand `dest` is not in global mem.')
    if not isinstance(dest.obj, Matrix):
        raise InternalError('ptr: operand `dest` is not a matrix')

    # dest inherits its logical view (and lead dim) from the source matrix
    src_matrix = src.obj
    dest.data_view = DataView(rows=src_matrix.get_actual_num_rows(),
                              columns=src_matrix.get_actual_num_cols(),
                              lead_dim=src_matrix.num_rows,
                              is_transposed=False)

    self._dest = dest
    self._src = src
    self._is_ready = True

    src.add_user(self)
    dest.add_user(self)
def __init__(self, vm: VM, src: Symbol, dest: Symbol, alpha: float, beta: float, num_threads: int):
    """Create a store instruction that writes register contents to global memory.

    Validates operand kinds, registers users, and derives `dest`'s data view
    from the destination matrix itself.

    Raises:
        InternalError: on any operand kind/type mismatch.
    """
    super(StoreRegToGlb, self).__init__(vm)

    # src must be a register-memory object, dest a matrix in global memory
    if src.stype != SymbolType.Register:
        raise InternalError('store: operand `src` is not in reg mem')
    if not isinstance(src.obj, RegMemObject):
        raise InternalError(f'store: operand `src` is registers, instead: {type(src.obj)}')
    if dest.stype != SymbolType.Global:
        raise InternalError('store: operand `dest` is not in global memory.')
    if not isinstance(dest.obj, Matrix):
        raise InternalError('store: operand `dest` is not a matrix')

    for symbol in (src, dest):
        symbol.add_user(self)

    # The view is derived from the destination matrix's own shape
    dest_matrix = dest.obj
    dest.data_view = DataView(rows=dest_matrix.get_actual_num_rows(),
                              columns=dest_matrix.get_actual_num_cols(),
                              lead_dim=dest_matrix.num_rows,
                              is_transposed=False)

    self._dest: Symbol = dest
    self._src: Symbol = src
    # GEMM scaling factors: dest = alpha * result + beta * dest
    # (presumed from naming — TODO confirm against the generated kernel)
    self._alpha = alpha
    self._beta = beta
    self._num_threads: int = num_threads
    self._is_ready: bool = True
def _check(self) -> None:
    """Validate operands of a shared-memory load.

    `src` must be a matrix in global memory and `dest` a matrix in shared
    memory; anything else is an internal consistency error.

    Raises:
        InternalError: on any operand kind/type mismatch.
    """
    src, dest = self._src, self._dest

    if src.stype != SymbolType.Global:
        raise InternalError('shr-load: `src` operand is not in global mem.')
    if not isinstance(src.obj, Matrix):
        raise InternalError(f'shr-load: `src` operand is not a matrix, instead: {src.obj}')

    if dest.stype != SymbolType.SharedMem:
        raise InternalError('shr-load: `dest` operand is not in shr. mem.')
    if not isinstance(dest.obj, Matrix):
        raise InternalError(f'shr-load: `dest` operand is not a matrix, instead: {dest.obj}')
def shm_mem_loader_factory(vm, dest, src, shr_mem, num_threads, load_and_transpose=False):
    """Select and construct a shared-memory patch loader for `src`.

    Chooses between "exact" and "extended" loaders: an extended loader is used
    when the tail of an active thread group can spill into the next column;
    otherwise an exact one is used. The transpose variants are picked when
    `load_and_transpose` is set.

    Raises:
        InternalError: if `src` does not wrap a Matrix object.
    """
    if not isinstance(src.obj, Matrix):
        raise InternalError('shm-factory: `src` operand is not a matrix')

    params = {'vm': vm,
              'dest': dest,
              'src': src,
              'shr_mem': shr_mem,
              'num_threads': num_threads,
              'load_and_transpose': load_and_transpose}

    # Number of elements touched per column when loads are rounded up to a
    # whole multiple of the thread-block size.
    num_loads_per_column = ceil(src.obj.get_actual_num_rows() / num_threads) * num_threads

    # If the padded load span fits within the lead dimension, the thread tail
    # cannot reach the next column and an exact loader suffices.
    if src.obj.num_rows > num_loads_per_column:
        loader_cls = ExactTransposePatchLoader if load_and_transpose else ExactPatchLoader
    else:
        loader_cls = ExtendedTransposePatchLoader if load_and_transpose else ExtendedPatchLoader
    return loader_cls(**params)
def __init__(self, vm: VM, src: Symbol):
    """Create an instruction that zeroes a register array.

    Args:
        vm: the target virtual machine / code-generation context.
        src: symbol holding the register array to clear.

    Raises:
        InternalError: if `src` is not a register-array symbol.
    """
    super(ClearRegisters, self).__init__(vm)
    # Fix: the original message said 'ptr:', copy-pasted from GetElementPtr;
    # the prefix now names this instruction so diagnostics point at the right op.
    if src.stype != SymbolType.Register:
        raise InternalError('clear_registers: operand `src` is not in registers')

    self._is_ready = True
    self._src = src
    src.add_user(self)
def _check(self):
    """Validate the data views of both GEMM operands.

    Ensures both operands have assigned data views, that op1's layout matches
    the transposition requested for this GEMM, that the contraction (k) ranges
    of both operands agree, and that the result fits into the accumulator
    registers.

    Raises:
        InternalError: if a data view is missing or the register array is too small.
        GenerationError: on a layout or contraction-length mismatch.
    """
    view_op1 = self._op1.data_view
    view_op2 = self._op2.data_view
    if not view_op1:
        raise InternalError('symbol data view has not been assign to `op1`')

    if not view_op1.is_transposed == self._trans_a:
        raise GenerationError('`op1 layout does not match the layout request by gemm instr.`')

    if not view_op2:
        raise InternalError('gemm: symbol data view has not been assign to `op2`')

    is_requested_layout = view_op2.is_transposed == self._trans_b

    # layout op1 is transposed if necessary and layout has already been adjusted
    # Note: if a subsequent GEMM requires to change the current layout
    # the matrix is going to be reloaded to the shared memory
    k_range_op1 = view_op1.columns

    # Note: we do not reload op2 to the shared memory if the current gemm op. requires
    # a different layout in contrast to the one that has already been loaded to the shared memory
    k_range_op2 = view_op2.rows if is_requested_layout else view_op2.columns

    if k_range_op1 != k_range_op2:
        # Fix: the original leaked debug `print(...)` calls here; the view
        # diagnostics are now part of the raised error instead of stdout.
        raise GenerationError(
            f'gemm: mismatch of contraction length '
            f'k_range_op1( {k_range_op1} ) != k_range_op2( {k_range_op2} ); '
            f'view_op1: {view_op1}; view_op2: {view_op2}'
        )

    if view_op2.columns > self._dest.obj.size:
        msg = f'{view_op2.columns} > {self._dest.obj.size}'
        raise InternalError(f'gemm: contraction length is bigger than reg. size i.e, {msg}')
def __init__(self, vm: VM, src: Symbol, dest: Symbol, shr_mem: Symbol, num_threads: int):
    """Create a store instruction that writes register contents to shared memory.

    Validates operand kinds, registers users, derives `dest`'s data view from
    the destination matrix, and records the shared-memory volume the store
    will occupy.

    Raises:
        InternalError: on any operand kind/type mismatch.
    """
    super(StoreRegToShr, self).__init__(vm)
    if src.stype != SymbolType.Register:
        raise InternalError('store: operand `src` is not in registers')
    if not isinstance(src.obj, RegMemObject):
        raise InternalError(f'store: operand `src` is not registers, instead: {type(src.obj)}')

    if dest.stype != SymbolType.SharedMem:
        raise InternalError('store: operand `dest` is not in shared mem.')
    if not isinstance(dest.obj, Matrix):
        # Fix: the original reported type(src.obj) here, which misreports the
        # offending operand — the check is about `dest`.
        raise InternalError(f'store: operand `dest` is not a matrix, instead: {type(dest.obj)}')

    src.add_user(self)
    dest.add_user(self)
    shr_mem.add_user(self)

    dest.data_view = DataView(rows=dest.obj.get_actual_num_rows(),
                              columns=dest.obj.get_actual_num_cols(),
                              lead_dim=dest.obj.num_rows,
                              is_transposed=False)

    self._dest: Symbol = dest
    self._src: Symbol = src
    self._shr_mem: Symbol = shr_mem
    self._num_threads: int = num_threads
    # Offset into the shared-memory pool; assigned later during allocation.
    self._shr_mem_offset: Union[int, None] = None

    # Total number of elements the patch occupies in shared memory
    view: DataView = self._dest.data_view
    self._shm_volume: int = view.rows * view.columns