def _check2_inplace(a0, a1): """Validate the inputs of a binary generic array operation (in-place)""" s0, s1 = len(a0), len(a1) sr = max(s0, s1) if s0 != sr or (s1 != sr and s1 != 1): raise Exception("Incompatible argument sizes: %i and %i" % (s0, s1)) elif type(a0) is not type(a1): # noqa raise Exception("Type mismatch!") return sr
def _check2_mask(a0, a1): """Validate the inputs of a binary mask-producing array operation""" s0, s1 = len(a0), len(a1) sr = max(s0, s1) if (s0 != sr and s0 != 1) or (s1 != sr and s1 != 1): raise Exception("Incompatible argument sizes: %i and %i" % (s0, s1)) elif type(a0) is not type(a1): # noqa raise Exception("Type mismatch!") ar = a0.MaskType.empty_(sr if a0.Size == Dynamic else 0) return (ar, sr)
def atan2_(a0, a1):
    """Element-wise two-argument arctangent ``_ek.atan2(a0[i], a1[i])``.

    Raises:
        Exception: if ``a0`` is not floating point, if the operands fail
            ``_check2`` validation, or if ``a0.IsSpecial`` is set.
    """
    if not a0.IsFloat:
        raise Exception("atan2(): requires floating point operands!")
    result, count = _check2(a0, a1)
    if a0.IsSpecial:
        raise Exception("atan2(): unsupported array type!")
    for idx in range(count):
        result[idx] = _ek.atan2(a0[idx], a1[idx])
    return result
def tgamma_(a0):
    """Element-wise gamma function of ``a0`` via ``_ek.tgamma``.

    Raises:
        Exception: if ``a0`` is not floating point, or if
            ``a0.IsSpecial`` is set.
    """
    if not a0.IsFloat:
        raise Exception("tgamma(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        raise Exception("tgamma(): unsupported array type!")
    for idx in range(count):
        result[idx] = _ek.tgamma(a0[idx])
    return result
def rcp_(a0):
    """Element-wise reciprocal of ``a0``.

    Plain arrays apply ``_ek.rcp`` per component. Complex and quaternion
    arrays return ``conj(a0) * rcp(squared_norm(a0))``, i.e.
    1/z = conj(z) / |z|^2.

    Raises:
        Exception: if ``a0`` is not floating point, or is a special type
            that is neither complex nor quaternion.
    """
    if not a0.IsFloat:
        raise Exception("rcp(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not (a0.IsComplex or a0.IsQuaternion):
            raise Exception('rcp(): unsupported array type!')
        return _ek.conj(a0) * _ek.rcp(_ek.squared_norm(a0))
    for idx in range(count):
        result[idx] = _ek.rcp(a0[idx])
    return result
def abs_(a0):
    """Absolute value of ``a0``.

    Special arrays that are not matrices are reduced to ``_ek.norm(a0)``.
    All other arrays, matrices included, apply ``_ek.abs`` per component.

    Raises:
        Exception: if ``a0`` is not arithmetic.
    """
    if not a0.IsArithmetic:
        raise Exception("abs(): requires arithmetic operands!")
    if a0.IsSpecial and not a0.IsMatrix:
        return _ek.norm(a0)
    result, count = _check1(a0)
    for idx in range(count):
        result[idx] = _ek.abs(a0[idx])
    return result
def _check3_select(a0, a1, a2): """Validate the inputs of a select() array operation""" s0, s1, s2 = len(a0), len(a1), len(a2) sr = max(s0, s1, s2) if (s0 != sr and s0 != 1) or (s1 != sr and s1 != 1) or \ (s2 != sr and s2 != 1): raise Exception("Incompatible argument sizes: %i, %i, and %i" % (s0, s1, s2)) elif type(a1) is not type(a2) or type(a0) is not type(a1).MaskType: # noqa raise Exception("Type mismatch!") ar = a1.empty_(sr if a0.Size == Dynamic else 0) return (ar, sr)
def atanh_(a0):
    """Element-wise inverse hyperbolic tangent of ``a0``.

    Plain arrays apply ``_ek.atanh`` per component. Complex arrays use
    atanh(z) = log((1 + z) / (1 - z)) / 2.

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("atanh(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("atanh(): unsupported array type!")
        return _ek.log((1 + a0) / (1 - a0)) * .5
    for idx in range(count):
        result[idx] = _ek.atanh(a0[idx])
    return result
def acosh_(a0):
    """Element-wise inverse hyperbolic cosine of ``a0``.

    Plain arrays apply ``_ek.acosh`` per component. Complex arrays use
    acosh(z) = 2 * log(sqrt((z + 1) / 2) + sqrt((z - 1) / 2)).

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("acosh(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("acosh(): unsupported array type!")
        root_p = _ek.sqrt(.5 * (a0 + 1))
        root_m = _ek.sqrt(.5 * (a0 - 1))
        return 2 * _ek.log(root_p + root_m)
    for idx in range(count):
        result[idx] = _ek.acosh(a0[idx])
    return result
def log2_(a0):
    """Element-wise base-2 logarithm of ``a0``.

    Plain arrays apply ``_ek.log2`` per component. Complex arrays get
    real part ``0.5 * log2(|z|^2)`` and imaginary part ``arg(z)`` scaled
    by ``_ek.InvLogTwo`` (presumably 1/ln 2).

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("log2(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("log2(): unsupported array type!")
        result.real = .5 * _ek.log2(_ek.squared_norm(a0))
        result.imag = _ek.arg(a0) * _ek.InvLogTwo
        return result
    for idx in range(count):
        result[idx] = _ek.log2(a0[idx])
    return result
def truediv_(a0, a1):
    """True division ``a0 / a1``.

    Special arrays compute ``a0 * a1.rcp_()`` without calling ``_check2``.
    Other arrays divide component by component.

    Raises:
        Exception: if ``a0`` is not floating point, pointing the user to
            ``//`` for integer arrays.
    """
    if not a0.IsFloat:
        raise Exception("Use the floor division operator \"//\" for "
                        "Enoki integer arrays.")
    if a0.IsSpecial:
        return a0 * a1.rcp_()
    result, count = _check2(a0, a1)
    for idx in range(count):
        result[idx] = a0[idx] / a1[idx]
    return result
def cot_(a0):
    """Element-wise cotangent of ``a0``.

    Plain arrays apply ``_ek.cot`` per component. Complex arrays return
    cos(z) / sin(z) built from ``_ek.sincos`` and skip ``_check1``.

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("cot(): requires floating point operands!")
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("cot(): unsupported array type!")
        sin_z, cos_z = _ek.sincos(a0)
        return cos_z / sin_z
    result, count = _check1(a0)
    for idx in range(count):
        result[idx] = _ek.cot(a0[idx])
    return result
def asin_(a0):
    """Element-wise arcsine of ``a0``.

    Plain arrays apply ``_ek.asin`` per component. Complex arrays use
    asin(z) = -i * log(i*z + sqrt(1 - z^2)).

    Raises:
        Exception: if ``a0`` is not floating point, or is a special type
            other than complex. The complex formula relies on
            ``.real``/``.imag`` and a two-component constructor, so it
            must not be applied to e.g. quaternions or matrices.
    """
    if not a0.IsFloat:
        raise Exception("asin(): requires floating point operands!")
    ar, sr = _check1(a0)
    if not a0.IsSpecial:
        for i in range(sr):
            ar[i] = _ek.asin(a0[i])
    elif a0.IsComplex:
        # i*z for z = x + iy is (-y) + ix
        iz = type(a0)(-a0.imag, a0.real)
        tmp = _ek.log(iz + _ek.sqrt(1 - _ek.sqr(a0)))
        # Multiply by -i: -i * (a + ib) = b - ia
        ar.real = tmp.imag
        ar.imag = -tmp.real
    else:
        raise Exception("asin(): unsupported array type!")
    return ar
def atan_(a0):
    """Element-wise arctangent of ``a0``.

    Plain arrays apply ``_ek.atan`` per component. Complex arrays use
    atan(z) = (i/2) * log((i + z) / (i - z)).

    Raises:
        Exception: if ``a0`` is not floating point, or is a special type
            other than complex. The complex formula builds two-component
            values and reads ``.real``/``.imag``, so it must not be
            applied to e.g. quaternions or matrices.
    """
    if not a0.IsFloat:
        raise Exception("atan(): requires floating point operands!")
    ar, sr = _check1(a0)
    if not a0.IsSpecial:
        for i in range(sr):
            ar[i] = _ek.atan(a0[i])
    elif a0.IsComplex:
        im = type(a0)(0, 1)
        # tmp = log((i - z) / (i + z)) = -log((i + z) / (i - z)),
        # so atan(z) = -(i/2) * tmp = tmp.imag/2 - i * tmp.real/2
        tmp = _ek.log((im - a0) / (im + a0))
        return type(a0)(tmp.imag * .5, -tmp.real * 0.5)
    else:
        raise Exception("atan(): unsupported array type!")
    return ar
def exp2_(a0):
    """Element-wise base-2 exponential of ``a0``.

    Plain arrays apply ``_ek.exp2`` per component. Complex arrays use
    2^(x + iy) = 2^x * (cos(y*L) + i*sin(y*L)) with ``L = _ek.LogTwo``
    (presumably ln 2).

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("exp2(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("exp2(): unsupported array type!")
        sin_t, cos_t = _ek.sincos(a0.imag * _ek.LogTwo)
        magnitude = _ek.exp2(a0.real)
        result.real = magnitude * cos_t
        result.imag = magnitude * sin_t
        return result
    for idx in range(count):
        result[idx] = _ek.exp2(a0[idx])
    return result
def cosh_(a0):
    """Element-wise hyperbolic cosine of ``a0``.

    Plain arrays apply ``_ek.cosh`` per component. Complex arrays use
    cosh(x + iy) = cosh(x)*cos(y) + i*sinh(x)*sin(y).

    Raises:
        Exception: if ``a0`` is not floating point, or is a non-complex
            special type.
    """
    if not a0.IsFloat:
        raise Exception("cosh(): requires floating point operands!")
    result, count = _check1(a0)
    if a0.IsSpecial:
        if not a0.IsComplex:
            raise Exception("cosh(): unsupported array type!")
        sin_y, cos_y = _ek.sincos(a0.imag)
        sinh_x, cosh_x = _ek.sincosh(a0.real)
        result.real = cosh_x * cos_y
        result.imag = sinh_x * sin_y
        return result
    for idx in range(count):
        result[idx] = _ek.cosh(a0[idx])
    return result
def isr_(a0, a1):
    """In-place element-wise right shift ``a0[i] >>= a1[i]``.

    Returns:
        ``a0`` itself, after modification.

    Raises:
        Exception: if ``a0`` is not integral, or if ``_check2_inplace``
            rejects the operands.
    """
    if not a0.IsIntegral:
        raise Exception("isr(): requires integral operands!")
    # Augmented assignment is kept so each element's in-place operator runs.
    for idx in range(_check2_inplace(a0, a1)):
        a0[idx] >>= a1[idx]
    return a0
def accum_grad_(a, grad):
    """Forward ``grad[i]`` to ``a[i].accum_grad_`` for every index of ``a``.

    Returns nothing; the effect is entirely through the components'
    ``accum_grad_`` methods.

    Raises:
        Exception: if ``a`` is not a differentiable array type.
    """
    if not a.IsDiff:
        raise Exception("Expected a differentiable array type!")
    for idx in range(len(a)):
        a[idx].accum_grad_(grad[idx])
def neg_(a0):
    """Element-wise negation ``-a0[i]`` into a fresh array from ``_check1``.

    Raises:
        Exception: if ``a0`` is not arithmetic.
    """
    if not a0.IsArithmetic:
        raise Exception("neg(): requires arithmetic operands!")
    result, count = _check1(a0)
    for idx in range(count):
        result[idx] = -a0[idx]
    return result
def isub_(a0, a1):
    """In-place element-wise subtraction ``a0[i] -= a1[i]``.

    Returns:
        ``a0`` itself, after modification.

    Raises:
        Exception: if ``a0`` is not arithmetic, or if ``_check2_inplace``
            rejects the operands.
    """
    if not a0.IsArithmetic:
        raise Exception("isub(): requires arithmetic operands!")
    # Augmented assignment is kept so each element's in-place operator runs.
    for idx in range(_check2_inplace(a0, a1)):
        a0[idx] -= a1[idx]
    return a0
def sub_(a0, a1):
    """Element-wise subtraction ``a0[i] - a1[i]`` into a fresh array.

    Raises:
        Exception: if ``a0`` is not arithmetic, or if ``_check2`` rejects
            the operands.
    """
    if not a0.IsArithmetic:
        raise Exception("sub(): requires arithmetic operands!")
    result, count = _check2(a0, a1)
    for idx in range(count):
        result[idx] = a0[idx] - a1[idx]
    return result
def popcnt_(a0):
    """Element-wise population count via ``_ek.popcnt``.

    Raises:
        Exception: if ``a0`` is not integral.
    """
    if not a0.IsIntegral:
        raise Exception("popcnt(): requires integral operands!")
    result, count = _check1(a0)
    for idx in range(count):
        result[idx] = _ek.popcnt(a0[idx])
    return result
def matmul_(a0, a1):
    """Matrix/vector product ``a0 @ a1``.

    Both operands must have the same ``Size`` and each must be a matrix
    or a vector. Dispatch, by operand kind:

    * vector @ vector: ``_ek.dot(a0, a1)``
    * matrix @ vector: ``sum_i a0[i] * a1[i]``, accumulated with fmadd
    * vector @ matrix: ``out[i] = dot(a0, a1[i])``
    * matrix @ matrix: ``out[j] = sum_i a0[i] * full(a1[i, j])``

    Raises:
        Exception: on unsupported operand shapes.
    """
    shapes_ok = (a0.Size == a1.Size
                 and (a0.IsMatrix or a0.IsVector)
                 and (a1.IsMatrix or a1.IsVector))
    if not shapes_ok:
        raise Exception("matmul(): unsupported operand shape!")

    if a0.IsVector and a1.IsVector:
        return _ek.dot(a0, a1)

    if a0.IsMatrix and a1.IsVector:
        total = a0[0] * a1[0]
        for k in range(1, a1.Size):
            total = _ek.fmadd(a0[k], a1[k], total)
        return total

    if a0.IsVector and a1.IsMatrix:
        out = a1.Value()
        for k in range(a1.Size):
            out[k] = _ek.dot(a0, a1[k])
        return out

    # matrix @ matrix; _check2 also validates the operand types.
    out, _ = _check2(a0, a1)
    for outer in range(a0.Size):
        total = a0[0] * _ek.full(a0.Value, a1[0, outer])
        for inner in range(1, a0.Size):
            scale = _ek.full(a0.Value, a1[inner, outer])
            total = _ek.fmadd(a0[inner], scale, total)
        out[outer] = total
    return out
def min_(a0, a1):
    """Element-wise minimum ``_ek.min(a0[i], a1[i])`` into a fresh array.

    Raises:
        Exception: if ``a0`` is not arithmetic, or if ``_check2`` rejects
            the operands.
    """
    if not a0.IsArithmetic:
        raise Exception("min(): requires arithmetic operands!")
    result, count = _check2(a0, a1)
    for idx in range(count):
        result[idx] = _ek.min(a0[idx], a1[idx])
    return result
def mod_(a0, a1):
    """Element-wise remainder ``a0[i] % a1[i]`` into a fresh array.

    Raises:
        Exception: if ``a0`` is not integral, or if ``_check2`` rejects
            the operands.
    """
    # The guard tests IsIntegral, so the message must say "integral"
    # (it previously said "arithmetic", contradicting the check).
    if not a0.IsIntegral:
        raise Exception("mod(): requires integral operands!")
    ar, sr = _check2(a0, a1)
    for i in range(sr):
        ar[i] = a0[i] % a1[i]
    return ar
def mulhi_(a0, a1):
    """Element-wise high half of the product, ``_ek.mulhi(a0[i], a1[i])``.

    Raises:
        Exception: if ``a0`` is not integral, or if ``_check2`` rejects
            the operands.
    """
    # Check the element type before _check2 validates sizes and allocates
    # the output, as every other operation in this module does.
    if not a0.IsIntegral:
        raise Exception("mulhi(): requires integral operands!")
    ar, sr = _check2(a0, a1)
    for i in range(sr):
        ar[i] = _ek.mulhi(a0[i], a1[i])
    return ar
def imod_(a0, a1):
    """In-place element-wise remainder ``a0[i] %= a1[i]``.

    Returns:
        ``a0`` itself, after modification.

    Raises:
        Exception: if ``a0`` is not integral, or if ``_check2_inplace``
            rejects the operands.
    """
    # The guard tests IsIntegral, so the message must say "integral"
    # (it previously said "arithmetic", contradicting the check).
    if not a0.IsIntegral:
        raise Exception("imod(): requires integral operands!")
    sr = _check2_inplace(a0, a1)
    for i in range(sr):
        a0[i] %= a1[i]
    return a0
def sr_(a0, a1):
    """Element-wise right shift ``a0[i] >> a1[i]`` into a fresh array.

    Raises:
        Exception: if ``a0`` is not integral, or if ``_check2`` rejects
            the operands.
    """
    if not a0.IsIntegral:
        raise Exception("sr(): requires integral operands!")
    result, count = _check2(a0, a1)
    for idx in range(count):
        result[idx] = a0[idx] >> a1[idx]
    return result
def ge_(a0, a1):
    """Element-wise comparison ``a0[i] >= a1[i]``.

    Results are written into the mask array provided by ``_check2_mask``.

    Raises:
        Exception: if ``a0`` is not arithmetic, or if ``_check2_mask``
            rejects the operands.
    """
    if not a0.IsArithmetic:
        raise Exception("ge(): requires arithmetic operands!")
    mask, count = _check2_mask(a0, a1)
    for idx in range(count):
        mask[idx] = a0[idx] >= a1[idx]
    return mask
def trunc_(a0):
    """Element-wise truncation via ``_ek.trunc`` into a fresh array.

    Raises:
        Exception: if ``a0`` is not arithmetic.
    """
    if not a0.IsArithmetic:
        raise Exception("trunc(): requires arithmetic operands!")
    result, count = _check1(a0)
    for idx in range(count):
        result[idx] = _ek.trunc(a0[idx])
    return result