Example #1
0
    def lower(op):
        """Lower a custom-datatype op to a call to a registered external function.

        Takes an op---a Cast, FloatImm, Call, or binary op (e.g. an Add)---and
        returns a pure extern call to the function registered for it in
        ``extern_func_map``, passing the op's operand(s):

        * Cast / FloatImm: ``op.value``
        * Call: every element of ``op.args``
        * binary op: ``op.a`` and ``op.b``

        The lookup key is the op's bit width, except for a Cast, where it is
        the ``(source bits, destination bits)`` pair.

        The return type of the call depends on the type of the op: if it is a
        registered custom type, then a uint of the same width (and lane count)
        as the custom type is returned. Otherwise, the type is unchanged.

        Raises
        ------
        RuntimeError
            If the op is of an unsupported kind, or no external function is
            registered in ``extern_func_map`` for its key.
        """
        # Classify the op first so an unsupported op is reported as such,
        # rather than as a confusing missing-key error.
        if isinstance(op, (_Cast, _FloatImm)):
            args = (op.value,)
        elif isinstance(op, _Call):
            args = op.args
        elif isinstance(op, _BinaryOpExpr):
            args = (op.a, op.b)
        else:
            raise RuntimeError(f"lowering unsupported op: {op.astext()}")

        dtype = op.dtype
        t = DataType(dtype)
        if get_type_registered(t.type_code):
            # Registered custom types are returned as a same-width uint,
            # keeping the vector lane suffix when there is one.
            dtype = "uint" + str(t.bits)
            if t.lanes > 1:
                dtype += "x" + str(t.lanes)

        key = t.bits
        if isinstance(op, _Cast):
            # Casts are keyed on both the source and destination widths.
            src_bits = DataType(op.value.dtype).bits
            key = (src_bits, t.bits)

        if key not in extern_func_map:
            raise RuntimeError(
                f"missing key {key} in extern_func_map for {op.astext()}")

        return call_pure_extern(dtype, extern_func_map[key], *args)
Example #2
0
 def __getitem__(self, index):
     """Return a Load expression reading ``index`` from this buffer.

     ``index`` is first passed through ``self._linear_index`` (presumably
     flattening a multi-dimensional index -- confirm against that method).
     For a vector content type (``lanes > 1``), the index is scaled by the
     lane count and expanded into a unit-stride Ramp, so all lanes are
     loaded at once.
     """
     t = DataType(self._content_type)
     index = self._linear_index(index)
     if t.lanes > 1:
         base = index * t.lanes
         # With a plain Python int index, ``base`` is an int and has no
         # ``dtype``. Calling ``const(1, base.dtype)`` would then raise
         # AttributeError, so pass a literal stride of 1 instead.
         stride = const(1, base.dtype) if hasattr(base, "dtype") else 1
         index = _expr.Ramp(base, stride, t.lanes)
     return _expr.Load(self._content_type, self._buffer_var, index)
Example #3
0
 def __getitem__(self, index):
     """Return a Load expression reading ``index`` from this buffer.

     The index goes through ``self._linear_index`` first. When the content
     type has more than one lane, a Ramp starting at ``index * lanes`` with
     stride 1 is loaded instead of a single element. The stride is a typed
     ``const`` when the start expression carries a dtype, otherwise a
     literal 1.
     """
     content_dtype = DataType(self._content_type)
     flat = self._linear_index(index)
     lanes = content_dtype.lanes
     if lanes > 1:
         start = flat * lanes
         if hasattr(start, "dtype"):
             step = const(1, start.dtype)
         else:
             step = 1
         flat = _expr.Ramp(start, step, lanes)
     return _expr.Load(self._content_type, self._buffer_var, flat)
Example #4
0
 def __setitem__(self, index, value):
     """Emit a Store of ``value`` into this buffer at ``index``.

     ``value`` is converted with ``_api.convert`` and must have exactly the
     buffer's content type, otherwise ValueError is raised. For a vector
     content type, ``index`` is replaced by a unit-stride Ramp starting at
     ``index * lanes``.
     """
     converted = _api.convert(value)
     content_type = self._content_type
     if converted.dtype != content_type:
         raise ValueError("data type does not match content type %s vs %s" %
                          (converted.dtype, content_type))
     lanes = DataType(content_type).lanes
     target = index if lanes <= 1 else _make.Ramp(index * lanes, 1, lanes)
     self._builder.emit(_make.Store(self._buffer_var, converted, target))
Example #5
0
 def __setitem__(self, index, value):
     """Emit a Store of ``value`` into this buffer at ``index``.

     ``value`` is passed through ``convert`` and must have exactly the
     buffer's content type. For a vector content type (``lanes > 1``),
     ``index`` is scaled by the lane count and expanded into a unit-stride
     Ramp, so all lanes are stored at once.

     Raises
     ------
     ValueError
         If the converted value's dtype differs from the content type.
     """
     value = convert(value)
     if value.dtype != self._content_type:
         raise ValueError("data type does not match content type %s vs %s" %
                          (value.dtype, self._content_type))
     t = DataType(self._content_type)
     if t.lanes > 1:
         base = index * t.lanes
         # With a plain Python int index, ``base`` is an int and has no
         # ``dtype``. Calling ``const(1, base.dtype)`` would then raise
         # AttributeError, so pass a literal stride of 1 instead.
         stride = const(1, base.dtype) if hasattr(base, "dtype") else 1
         index = _expr.Ramp(base, stride, t.lanes)
     self._builder.emit(_stmt.Store(self._buffer_var, value, index))
Example #6
0
 def __setitem__(self, index, value):
     """Emit a Store of ``value`` into this buffer at ``index``.

     The value is run through ``convert`` and rejected with ValueError
     unless its dtype equals the content type. The index is passed through
     ``self._linear_index``. For vector content types, the index becomes a
     stride-1 Ramp starting at ``index * lanes``. The stride is a typed
     ``const`` when the start expression carries a dtype, otherwise a
     literal 1.
     """
     converted = convert(value)
     expected = self._content_type
     if converted.dtype != expected:
         raise ValueError("data type does not match content type %s vs %s" %
                          (converted.dtype, expected))
     flat = self._linear_index(index)
     lanes = DataType(expected).lanes
     if lanes > 1:
         start = flat * lanes
         step = const(1, start.dtype) if hasattr(start, "dtype") else 1
         flat = _expr.Ramp(start, step, lanes)
     self._builder.emit(_stmt.Store(self._buffer_var, converted, flat))
Example #7
0
 def lower(op):
     """Lower a custom-datatype op to a call to ``extern_func_name``.

     Takes an op---a Cast, a FloatImm, or a binary op (e.g. an Add)---and
     returns a pure extern call to ``extern_func_name``. The call receives
     ``op.value`` for a Cast or FloatImm, and ``op.a`` and ``op.b`` for a
     binary op. The return type of the call depends on the type of the op:
     if it is a registered custom type, then a uint of the same width (and
     lane count) as the custom type is returned. Otherwise, the type is
     unchanged.

     Raises
     ------
     RuntimeError
         If ``op`` is neither a Cast/FloatImm nor an op with ``a`` and
         ``b`` operands.
     """
     dtype = op.dtype
     t = DataType(dtype)
     if get_type_registered(t.type_code):
         dtype = "uint" + str(t.bits)
         if t.lanes > 1:
             dtype += "x" + str(t.lanes)
     if isinstance(op, (_Cast, _FloatImm)):
         return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
     if not (hasattr(op, "a") and hasattr(op, "b")):
         # Fail with a clear message instead of an opaque AttributeError
         # on ``op.a`` for ops this lowering does not handle.
         raise RuntimeError(f"lowering unsupported op: {op}")
     return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
Example #8
0
def _dtype_is_float(value):
    if isinstance(value, float):
        return True
    return (isinstance(value, ExprOp)
            and DataType(value.dtype).type_code == DataTypeCode.FLOAT)
Example #9
0
def _dtype_is_int(value):
    if isinstance(value, int):
        return True
    return (isinstance(value, ExprOp)
            and DataType(value.dtype).type_code == DataTypeCode.INT)
Example #10
0
 def __getitem__(self, index):
     """Return a Load expression reading ``index`` from this buffer.

     For a vector content type, a unit-stride Ramp starting at
     ``index * lanes`` is loaded in place of the single index.
     """
     lanes = DataType(self._content_type).lanes
     target = _make.Ramp(index * lanes, 1, lanes) if lanes > 1 else index
     return _make.Load(self._content_type, self._buffer_var, target)
Example #11
0
def dtype_is_uint(value):
    """Report whether ``value`` is an unsigned-integer TIR expression.

    Only ``tir.expr.ExprOp`` instances can qualify, and only when their
    dtype's type code equals ``TypeCode.UINT``. Plain Python values always
    yield False.
    """
    if not isinstance(value, tir.expr.ExprOp):
        return False
    return DataType(value.dtype).type_code == TypeCode.UINT