def diff(self, d1: DimSize, d2: DimSize) -> DimSize:
    """Compute `d1 - d2` for the few cases that can be decided.

    Returns 0 when `self.symbolic_equal(d1, d2)` holds, and `d1` when `d2`
    is a member of `{0}`.

    Raises:
      core.InconclusiveDimensionOperation: in every other case.
    """
    if not self.symbolic_equal(d1, d2):
        # NOTE(review): set membership rather than `d2 == 0` looks deliberate,
        # presumably to avoid a direct `==` on a shape variable -- confirm.
        if d2 in {0}:
            return d1
        raise core.InconclusiveDimensionOperation(
            f"Subtracting shape variables is not supported ({d1} - {d2})")
    return 0
def threefry_2x32(keypair, count):
    """Apply the Threefry 2x32 hash.

    The counts are flattened, padded with a single zero when their number is
    odd, split into two halves that are fed to `threefry2x32_p`, and the
    result is trimmed and reshaped back to `count.shape`.

    Args:
      keypair: a pair of 32bit unsigned integers used for the key.
      count: an array of dtype uint32 used for the counts.

    Returns:
      An array of dtype uint32 with the same shape as `count`.

    Raises:
      TypeError: if the two keys and `count` are not all of dtype uint32.
      core.InconclusiveDimensionOperation: if computing `count.size % 2`
        is inconclusive.
    """
    key1, key2 = keypair
    if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == np.uint32:
        template = "threefry_2x32 requires uint32 arguments, got {}"
        raise TypeError(
            template.format([lax.dtype(x) for x in [key1, key2, count]]))

    try:
        needs_padding = count.size % 2
    except core.InconclusiveDimensionOperation as e:
        raise core.InconclusiveDimensionOperation(
            "jax.random functions have limited support for shape polymorphism. "
            "In particular, the product of the known dimensions must be even."
        ) from e

    flat_counts = count.ravel()
    if needs_padding:
        # Append one zero so the flattened counts split into equal halves.
        flat_counts = jnp.concatenate([flat_counts, np.uint32([0])])
    first_half, second_half = list(jnp.split(flat_counts, 2))

    hashed = threefry2x32_p.bind(key1, key2, first_half, second_half)
    out = jnp.concatenate(hashed)
    assert out.dtype == np.uint32

    if needs_padding:
        # Drop the element produced for the padding zero.
        out = out[:-1]
    return lax.reshape(out, count.shape)
def dilate(self, d: DimSize, dilation: DimSize) -> DimSize:
    """Implements `0 if d == 0 else 1 + dilation * (d - 1)`.

    Only the trivial case is handled: when `dilation` is a member of `{1}`
    the dimension `d` is returned unchanged.

    Raises:
      core.InconclusiveDimensionOperation: for any other `dilation`.
    """
    if dilation in {1}:
        return d
    raise core.InconclusiveDimensionOperation(
        f"Only dilation == 1 is supported for shape variables (var = {d}, "
        f"dilation = {dilation})")
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
    """Implements `(d - window_size) // window_stride + 1`.

    Only the trivial case `window_size == window_stride == 1` is supported,
    for which the formula reduces to `d` and `d` is returned unchanged.

    Raises:
      core.InconclusiveDimensionOperation: if the set
        `{window_size, window_stride}` differs from `{1}`.
    """
    if {window_size, window_stride} != {1}:
        # The message used to end without closing its opening parenthesis.
        raise core.InconclusiveDimensionOperation(
            "Only striding with window_size == window_stride == 1 is supported "
            f"for shape variables (var = {d}, window_size = {window_size}, "
            f"stride = {window_stride})")
    return d
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
    """Implements `(d - window_size) // window_stride + 1`.

    Only the trivial case `window_size == window_stride == 1` is supported,
    for which the formula reduces to `d` and `d` is returned unchanged.

    Raises:
      core.InconclusiveDimensionOperation: if the set
        `{window_size, window_stride}` differs from `{1}`.
    """
    if {window_size, window_stride} != {1}:
        # The message used to end without closing its opening parenthesis.
        raise core.InconclusiveDimensionOperation(
            "Striding is not supported for shape variables "
            f"(window_size = {window_size}, stride = {window_stride})")
    return d
def ge(self, other: DimSize) -> bool:
    """Decide `self >= other` from the bounds of `self - other`.

    Returns True when the lower bound of the difference is known and
    non-negative, and False when its upper bound is known and negative.

    Raises:
      core.InconclusiveDimensionOperation: when neither bound settles the
        comparison.
    """
    difference = _ensure_poly(self - other)
    lower, upper = difference.bounds()

    provably_non_negative = lower is not None and lower >= 0
    if provably_non_negative:
        return True

    provably_negative = upper is not None and upper < 0
    if provably_negative:
        return False

    raise core.InconclusiveDimensionOperation(
        f"Dimension polynomial comparison '{self}' >= '{other}' is inconclusive"
    )
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> int:
    """Divide the size of shape `s1` by the size of shape `s2`.

    Each shape is split by `_split_shape_ints` into its integer dimensions
    and its shape variables. The variables of both shapes must match as
    multisets; the integer parts are then delegated to the parent handler's
    `divide_shape_sizes`.

    Raises:
      core.InconclusiveDimensionOperation: if the two shapes do not carry
        the same shape variables with the same multiplicities.
    """
    numerator_ints, numerator_vars = _split_shape_ints(s1)
    denominator_ints, denominator_vars = _split_shape_ints(s2)

    numerator_var_counts = collections.Counter(numerator_vars)
    denominator_var_counts = collections.Counter(denominator_vars)
    if numerator_var_counts != denominator_var_counts:
        raise core.InconclusiveDimensionOperation(
            f"Shapes {s1} and {s2} must have the same set of shape variables."
        )

    parent_handler = super(DimensionHandlerVar, self)
    return parent_handler.divide_shape_sizes(numerator_ints, denominator_ints)
def divmod(self, divisor: DimSize) -> Tuple[DimSize, int]:
    """
    Floor division with remainder (divmod) generalized to polynomials.

    If the `divisor` is not a constant, the remainder must be 0.
    If the `divisor` is a constant, the remainder may be non 0, for consistency
    with integer divmod.

    The quotient is accumulated one term at a time: the leading term of what
    is left of the dividend is divided by the leading term of the divisor
    until the dividend is an int or reports `is_constant`.

    :param divisor: The value to divide by; it is passed through
      `_ensure_poly` before use.
    :return: Quotient resulting from polynomial division and integer remainder.
    :raises core.InconclusiveDimensionOperation: If a leading monomial cannot
      be divided by the divisor's leading monomial, if a leading coefficient
      is not an exact multiple of the divisor's leading coefficient, or if the
      divisor is not constant and a non-zero constant is left over.
    """
    divisor = _ensure_poly(divisor)
    dmon, dcount = divisor.leading_term
    dividend, quotient = self, 0
    # Every failure mode below reports the same message.
    err_msg = f"Dimension polynomial '{self}' is not a multiple of '{divisor}'"
    # invariant: self = dividend + divisor * quotient
    # the leading term of dividend decreases through the loop.
    while not (isinstance(dividend, int) or dividend.is_constant):
        mon, count = dividend.leading_term
        try:
            qmon = mon.divide(dmon)
        except core.InconclusiveDimensionOperation:
            # Replace the monomial-level error with the polynomial-level one.
            raise core.InconclusiveDimensionOperation(err_msg)
        qcount, rcount = divmod(count, dcount)
        if rcount != 0:
            # The leading coefficients must divide exactly.
            raise core.InconclusiveDimensionOperation(err_msg)

        # Move one quotient term from the dividend into the quotient.
        q = _DimPolynomial.from_coeffs({qmon: qcount})
        quotient += q
        dividend -= q * divisor  # type: ignore[assignment]

    # The loop ends only once `dividend` is an int or reports `is_constant`.
    dividend = int(dividend)  # type: ignore[assignment]
    if divisor.is_constant:
        # Constant divisor: finish with integer divmod, remainder allowed.
        q, r = divmod(dividend, int(divisor))  # type: ignore
        quotient += q
        remainder = r
    else:
        # Non-constant divisor: any leftover constant means no exact division.
        if dividend != 0:
            raise core.InconclusiveDimensionOperation(err_msg)
        remainder = 0

    # Optional self-check of the division identity.
    if config.jax_enable_checks:
        assert self == divisor * quotient + remainder
    return quotient, remainder
def stride(self, d: DimSize, window_size: DimSize, window_stride: DimSize) -> DimSize:
    """Implements `(d - window_size) // window_stride + 1`.

    The difference `d - window_size` is passed through `_ensure_poly` and
    divided by `window_stride` with its `divmod`; only the quotient is used.

    Raises:
      core.InconclusiveDimensionOperation: if the computation is
        inconclusive; the original error is chained as the cause and its
        text is included in the message.
    """
    try:
        # The remainder is deliberately ignored (floor division).
        quotient, _ = _ensure_poly(d - window_size).divmod(window_stride)
        return quotient + 1
    except core.InconclusiveDimensionOperation as e:
        raise core.InconclusiveDimensionOperation(
            f"Cannot compute stride for dimension '{d}', "
            f"window_size '{window_size}', stride '{window_stride}'. Reason: {e}."
        ) from e
def divide(self, divisor: '_DimMon') -> '_DimMon':
    """
    Divides by another monomial.

    Each exponent of `divisor` is subtracted from the matching exponent of
    `self` (a missing entry counts as 0). Entries whose exponent drops to
    zero are removed from the result.

    Raises a core.InconclusiveDimensionOperation if the result is not a
    monomial, i.e. if some exponent would become negative.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    remaining = collections.Counter(self)
    for var, divisor_exponent in divisor.items():
        exponent_left = self.get(var, 0) - divisor_exponent
        if exponent_left < 0:
            raise core.InconclusiveDimensionOperation(
                f"Cannot divide {self} by {divisor}.")
        if exponent_left > 0:
            remaining[var] = exponent_left
        elif exponent_left == 0:
            del remaining[var]
    return _DimMon(remaining)
def __int__(self):
    """Convert a constant dimension polynomial to an integer.

    When `self.is_constant` holds, the first value yielded by
    `self.values()` is converted with `op.index` and returned.

    Raises:
      core.InconclusiveDimensionOperation: if the polynomial is not constant.
    """
    if not self.is_constant:
        raise core.InconclusiveDimensionOperation(
            f"Dimension polynomial '{self}' is not constant")
    first_value = next(iter(self.values()))
    return op.index(first_value)
def greater_equal(self, d1: DimSize, d2: DimSize):
    """Decide `d1 >= d2` for the cases that can be settled.

    Returns True when `self.symbolic_equal(d1, d2)` holds, or when `d2` is
    not exactly a `DimVar` and `1 >= d2`.

    Raises:
      core.InconclusiveDimensionOperation: in every other case.
    """
    if self.symbolic_equal(d1, d2):
        return True
    # Exact type check, evaluated only after the equality test has failed.
    if type(d2) is not DimVar and 1 >= d2:
        return True
    raise core.InconclusiveDimensionOperation(
        f"Shape variable comparison {d1} >= {d2} is inconclusive")
def __eq__(self, other):
    """Compare with another shape variable.

    Returns True when `other` is a `DimVar` with the same `_varname`.
    This method never returns False.

    Raises:
      core.InconclusiveDimensionOperation: for any other `other`.
    """
    same_variable = isinstance(other, DimVar) and self._varname == other._varname
    if not same_variable:
        raise core.InconclusiveDimensionOperation(
            f"Shape variable comparison {self} == {other} is inconclusive")
    return True