def __init__(self, trace, val):
  """Bind this tracer to `trace` and store `val` as a TensorFlow value.

  A `tf.Tensor` or `tf.Variable` is stored unchanged. Any other value is
  converted with `tf.convert_to_tensor`, using the dtype reported by
  `xla.abstractify(val)` both for the intermediate numpy array and for the
  resulting tensor.
  """
  self._trace = trace
  if isinstance(val, (tf.Tensor, tf.Variable)):
    self.val = val
    return
  jax_dtype = xla.abstractify(val).dtype
  self.val = tf.convert_to_tensor(np.array(val, jax_dtype), dtype=jax_dtype)
def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
                              ) -> Sequence[ir.Value]:
  """Constant handler for ndarray literals, handling zero-size strides.

  In most cases this function calls
  ``_numpy_array_constant(val, canonicalize_types)``. There are two special
  cases:

  * ``float0`` arrays are staged as an all-``False`` boolean constant of the
    same shape (with ``canonicalize_types=False``).
  * Non-empty arrays with at least one stride of size zero (for example the
    output of ``lax.broadcast``, which uses ``np.broadcast_to``, which in turn
    uses size-zero strides) are staged as a constant holding only the
    sub-array along the non-zero-stride axes. An ``mhlo.BroadcastInDimOp``
    then restores the full shape. This avoids staging in large literals.

  Args:
    val: an ndarray.
    canonicalize_types: forwarded to ``_numpy_array_constant`` for the
      non-``float0`` cases.

  Returns:
    A sequence of IR values representing the constant ndarray staged into the
    IR being built; the zero-stride path returns a 1-tuple.
  """
  if dtypes.result_type(val) == dtypes.float0:
    return _numpy_array_constant(np.zeros(val.shape, dtype=np.bool_),
                                 canonicalize_types=False)

  # Compute the per-axis zero-stride mask once and reuse it below.
  zero_strides = np.equal(0, val.strides)
  if np.any(zero_strides) and val.size > 0:
    other_axes, = np.where(np.logical_not(zero_strides))
    # A zero stride means every position along that axis aliases the same
    # memory, so index 0 there and keep the remaining axes whole.
    collapsed_val = val[tuple(0 if is_zero else slice(None)
                              for is_zero in zero_strides)]
    # NOTE(review): the result type comes from xla.abstractify(val) no matter
    # what canonicalize_types is; if abstractify canonicalizes dtypes it may
    # disagree with the operand when canonicalize_types is False -- confirm.
    out = mhlo.BroadcastInDimOp(
        aval_to_ir_type(xla.abstractify(val)),
        _numpy_array_constant(collapsed_val, canonicalize_types)[0],
        dense_int_elements(other_axes)).result
    return (out,)

  return _numpy_array_constant(val, canonicalize_types)
def __init__(self, trace: 'TensorFlowTrace', val: TfValOrUnit):
  """Bind this tracer to `trace` and store `val` as a TensorFlow value.

  `core.unit`, `tf.Tensor` and `tf.Variable` values are stored unchanged.
  Anything else is converted with `tf.convert_to_tensor`, using the dtype
  reported by `xla.abstractify(val)` for both the numpy array and the tensor.
  """
  self._trace = trace
  if val is core.unit or isinstance(val, (tf.Tensor, tf.Variable)):
    self.val = val
    return
  target_dtype = xla.abstractify(val).dtype
  self.val = tf.convert_to_tensor(  # type: ignore
      np.array(val, target_dtype), dtype=target_dtype)
def _device_put_impl(x, device: Optional[Device] = None):
  """Place `x` on `device` and wrap the result.

  Device arrays are handed to `_copy_device_array_to_device`. Any other value
  is abstractified first; a `TypeError` from `xla.abstractify` is re-raised as
  a `TypeError` naming the offending argument. The value is then transferred
  with `device_put`, and the buffers are wrapped by the result handler built
  for its abstract value.
  """
  if not device_array.type_is_device_array(x):
    try:
      abstract_val = xla.abstractify(x)
    except TypeError as err:
      raise TypeError(
          f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
    handler = aval_to_result_handler(device, abstract_val)
    return handler(*device_put(x, device))
  return _copy_device_array_to_device(x, device)
def _numpy_array_constant(x: np.ndarray, canonicalize_types
                          ) -> Sequence[ir.Value]:
  """Stage the numpy array `x` into the IR as a single ``mhlo.ConstOp``.

  Args:
    x: the array to embed.
    canonicalize_types: if true, `x` is first converted to
      ``dtypes.canonicalize_dtype(x.dtype)``.

  Returns:
    A 1-tuple holding the result value of the constant op.
  """
  arr = (np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
         if canonicalize_types else x)
  # The IR type (element type and shape) is derived from the abstract value.
  # NOTE(review): xla.abstractify is used even when canonicalize_types is
  # False; if abstractify canonicalizes dtypes, a non-canonical `x` would get
  # an element type that disagrees with its data -- confirm.
  const_aval = xla.abstractify(arr)
  const_type = aval_to_ir_type(const_aval)

  buffer = arr
  if buffer.dtype == np.bool_:
    # Bit-pack booleans (little-endian bit order), presumably the layout the
    # MLIR bindings expect for i1 data -- TODO confirm. packbits flattens the
    # array, which is why the shape is passed explicitly below.
    buffer = np.packbits(buffer, bitorder='little')
  elif buffer.dtype == dtypes.bfloat16:
    # Reinterpret the 16-bit bfloat16 payload as uint16; presumably the
    # bindings cannot consume numpy's bfloat16 directly -- TODO confirm.
    buffer = buffer.view(np.uint16)

  attr = ir.DenseElementsAttr.get(np.ascontiguousarray(buffer),
                                  type=const_type.element_type,
                                  shape=const_aval.shape)
  return (mhlo.ConstOp(const_type, attr).result,)
def __init__(self, trace: 'TensorFlowTrace', val: TfValOrUnit):
  """Bind this tracer to `trace` and store `val` as a TensorFlow value.

  - `core.unit` is stored unchanged.
  - A `tf.Tensor` / `tf.Variable` is stored unchanged unless the dtype of its
    `abstractify` aval differs from the tensor's numpy dtype. In that case it
    is `tf.cast` to the aval dtype.
  - Anything else must be a numeric value (asserted unless
    `core.skip_checks`). It is converted to a tensor with the dtype from
    `xla.abstractify(val)`, and the resulting tracer aval is asserted to
    match that abstract value, ignoring weak types.
  """
  self._trace = trace
  if val is core.unit:
    self.val = val
    return

  if isinstance(val, (tf.Tensor, tf.Variable)):
    tensor_aval: core.ShapedArray = abstractify(val)
    # This is expected when JAX 64-bit is not enabled.
    dtype_differs = np.dtype(tensor_aval.dtype) != val.dtype.as_numpy_dtype  # type: ignore
    self.val = tf.cast(val, dtype=tensor_aval.dtype) if dtype_differs else val
    return

  # Must be a numeric value.
  assert core.skip_checks or _is_tfval(val), f"Non TfVal: {val}"
  aval = xla.abstractify(val)  # type: ignore
  self.val = tf.convert_to_tensor(  # type: ignore
      np.array(val, aval.dtype), dtype=aval.dtype)
  assert core.skip_checks or aval.strip_weak_type() == self.aval.strip_weak_type(), (
      f"Expected {aval}, got {self.aval}")
def arg_spec(x: Any) -> ArgSpec:
  """Return the ``(aval, device)`` pair describing argument `x`.

  The abstract value comes from ``xla.abstractify(x)``, and any error raised
  there propagates. The device is ``x._device`` when that attribute can be
  read, and ``None`` otherwise.
  """
  aval = xla.abstractify(x)
  try:
    device = x._device
  except Exception:  # pylint: disable=broad-except
    # Best effort: values whose ``_device`` cannot be read are reported with
    # device None. Catch Exception rather than using a bare ``except:`` so
    # KeyboardInterrupt / SystemExit are not swallowed.
    device = None
  return aval, device
def aval(v: core.Var) -> core.AbstractValue:
  """Return the abstract value of `v` (a variable or a literal).

  Objects whose exact type is ``core.Literal`` are abstractified from their
  wrapped ``val``; anything else supplies its own ``aval`` attribute.
  """
  if type(v) is not core.Literal:
    return v.aval
  return xla.abstractify(v.val)
def aval(v):
  """Return the abstract value of `v`.

  For an object whose exact type is ``core.Literal``, this is
  ``xla.abstractify(v.val)``; otherwise it is ``v.aval``.
  """
  is_literal = type(v) is core.Literal
  return xla.abstractify(v.val) if is_literal else v.aval
def aval(v):
  """Return ``xla.abstractify(v.val)`` for exact ``core.Literal`` objects,
  else ``v.aval``."""
  if type(v) is core.Literal:
    literal_value = v.val
    return xla.abstractify(literal_value)
  return v.aval