def lower(op):
    """Lower an op on a (possibly custom) datatype to an external call.

    ``op`` may be a Cast or FloatImm (its ``value`` is passed), a Call (its
    ``args`` are passed), or a binary op such as an Add (its ``a`` and ``b``
    are passed). The callee name is looked up in ``extern_func_map``:
    Casts are keyed on ``(source bits, destination bits)`` and every other
    op on the destination bit width.

    The call's return dtype is ``op.dtype``, unless that dtype's type code
    is registered as a custom type. In that case it is ``uint<bits>`` (with
    an ``x<lanes>`` suffix for vector types).

    Raises:
        RuntimeError: if ``extern_func_map`` has no entry for the key, or if
            ``op`` is not one of the supported node kinds.
    """
    out_dtype = op.dtype
    info = DataType(out_dtype)
    if get_type_registered(info.type_code):
        # Custom types are represented by an unsigned int of the same width.
        out_dtype = "uint" + str(info.bits)
        if info.lanes > 1:
            out_dtype = out_dtype + "x" + str(info.lanes)

    if isinstance(op, _Cast):
        # Casts need both widths to pick the right conversion routine.
        lookup = (DataType(op.value.dtype).bits, info.bits)
    else:
        lookup = info.bits
    if lookup not in extern_func_map:
        raise RuntimeError(
            f"missing key {lookup} in extern_func_map for {op.astext()}")
    func_name = extern_func_map[lookup]

    if isinstance(op, (_Cast, _FloatImm)):
        operands = (op.value,)
    elif isinstance(op, _Call):
        operands = op.args
    elif isinstance(op, _BinaryOpExpr):
        operands = (op.a, op.b)
    else:
        raise RuntimeError(f"lowering unsupported op: {op.astext()}")
    return call_pure_extern(out_dtype, func_name, *operands)
def __getitem__(self, index):
    """Return a Load of this buffer variable at ``index``.

    ``index`` is first passed through ``self._linear_index``. For vector
    content types (more than one lane), the element index is scaled by the
    lane count and expanded into a unit-stride Ramp, so that all lanes are
    loaded.
    """
    t = DataType(self._content_type)
    index = self._linear_index(index)
    if t.lanes > 1:
        base = index * t.lanes
        # When ``index`` is a plain Python int, ``base`` is an int with no
        # ``dtype`` attribute. Reading ``base.dtype`` would raise
        # AttributeError, so use a plain int stride in that case. Otherwise
        # give the stride the same dtype as ``base``.
        stride = const(1, base.dtype) if hasattr(base, "dtype") else 1
        index = _expr.Ramp(base, stride, t.lanes)
    return _expr.Load(self._content_type, self._buffer_var, index)
def __getitem__(self, index):
    """Build a Load expression reading this buffer variable at ``index``.

    The index is linearized via ``self._linear_index``. When the content type
    is a vector (more than one lane), the index becomes a unit-stride Ramp
    that starts at ``index * lanes``. The stride gets the start's dtype when
    the start has one, and is the plain int ``1`` otherwise.
    """
    elem_type = DataType(self._content_type)
    flat = self._linear_index(index)
    lanes = elem_type.lanes
    if lanes <= 1:
        return _expr.Load(self._content_type, self._buffer_var, flat)
    start = flat * lanes
    step = const(1, start.dtype) if hasattr(start, "dtype") else 1
    vector_index = _expr.Ramp(start, step, lanes)
    return _expr.Load(self._content_type, self._buffer_var, vector_index)
def __setitem__(self, index, value):
    """Emit a Store of ``value`` into this buffer variable at ``index``.

    ``value`` is converted with ``_api.convert`` and must have exactly the
    buffer's content type. For vector content types (more than one lane),
    the index becomes a Ramp of stride 1 starting at ``index * lanes``.

    Raises:
        ValueError: if the converted value's dtype differs from the content
            type.
    """
    value = _api.convert(value)
    content_type = self._content_type
    if value.dtype != content_type:
        raise ValueError("data type does not match content type %s vs %s" % (value.dtype, content_type))
    lanes = DataType(content_type).lanes
    if lanes > 1:
        index = _make.Ramp(index * lanes, 1, lanes)
    store = _make.Store(self._buffer_var, value, index)
    self._builder.emit(store)
def __setitem__(self, index, value):
    """Emit a Store of ``value`` into this buffer variable at ``index``.

    ``value`` is converted with ``convert`` and must have exactly the
    buffer's content type. For vector content types (more than one lane),
    the index becomes a unit-stride Ramp starting at ``index * lanes``.

    Raises:
        ValueError: if the converted value's dtype differs from the content
            type.
    """
    value = convert(value)
    if value.dtype != self._content_type:
        raise ValueError("data type does not match content type %s vs %s" % (value.dtype, self._content_type))
    t = DataType(self._content_type)
    if t.lanes > 1:
        base = index * t.lanes
        # When ``index`` is a plain Python int, ``base`` is an int with no
        # ``dtype`` attribute. Reading ``base.dtype`` would raise
        # AttributeError, so use a plain int stride in that case. Otherwise
        # give the stride the same dtype as ``base``.
        stride = const(1, base.dtype) if hasattr(base, "dtype") else 1
        index = _expr.Ramp(base, stride, t.lanes)
    self._builder.emit(_stmt.Store(self._buffer_var, value, index))
def __setitem__(self, index, value):
    """Emit a Store of ``value`` into this buffer variable at ``index``.

    ``value`` is converted with ``convert`` and must have exactly the
    buffer's content type. The index is linearized via
    ``self._linear_index``. For vector content types it is then expanded
    into a unit-stride Ramp starting at ``index * lanes``.

    Raises:
        ValueError: if the converted value's dtype differs from the content
            type.
    """
    value = convert(value)
    expected = self._content_type
    if value.dtype != expected:
        raise ValueError("data type does not match content type %s vs %s" % (value.dtype, expected))
    flat = self._linear_index(index)
    lanes = DataType(expected).lanes
    if lanes > 1:
        start = flat * lanes
        # A plain-int start has no dtype; only typed starts get a typed stride.
        step = const(1, start.dtype) if hasattr(start, "dtype") else 1
        flat = _expr.Ramp(start, step, lanes)
    store = _stmt.Store(self._buffer_var, value, flat)
    self._builder.emit(store)
def lower(op):
    """Lower an op to a pure call of the external function ``extern_func_name``.

    A Cast or FloatImm passes its ``value`` as the single argument. Any other
    op is treated as a binary op (e.g. an Add) and passes ``op.a`` and
    ``op.b``.

    The call's return dtype is ``op.dtype``, unless that dtype's type code is
    registered as a custom type. In that case it is ``uint<bits>`` (with an
    ``x<lanes>`` suffix for vector types).
    """
    out_dtype = op.dtype
    info = DataType(out_dtype)
    if get_type_registered(info.type_code):
        # Custom types are represented by an unsigned int of the same width.
        out_dtype = "uint" + str(info.bits)
        if info.lanes > 1:
            out_dtype = out_dtype + "x" + str(info.lanes)
    if isinstance(op, (_Cast, _FloatImm)):
        operands = (op.value,)
    else:
        operands = (op.a, op.b)
    return tvm.tir.call_pure_extern(out_dtype, extern_func_name, *operands)
def _dtype_is_float(value): if isinstance(value, float): return True return (isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.FLOAT)
def _dtype_is_int(value): if isinstance(value, int): return True return (isinstance(value, ExprOp) and DataType(value.dtype).type_code == DataTypeCode.INT)
def __getitem__(self, index):
    """Return a Load of this buffer variable at ``index``.

    For vector content types (more than one lane), the index becomes a Ramp
    of stride 1 starting at ``index * lanes``.
    """
    lanes = DataType(self._content_type).lanes
    if lanes > 1:
        index = _make.Ramp(index * lanes, 1, lanes)
    load = _make.Load(self._content_type, self._buffer_var, index)
    return load
def dtype_is_uint(value):
    """Return True only for a TIR ExprOp whose dtype is an unsigned int type."""
    if not isinstance(value, tir.expr.ExprOp):
        return False
    return DataType(value.dtype).type_code == TypeCode.UINT