def _broadcast_to(arr, shape):
  """Broadcasts ``arr`` to ``shape`` following NumPy broadcasting rules.

  Objects that define their own ``broadcast_to`` method handle the request
  themselves. Otherwise ``arr`` is converted to an ``ndarray`` (unless it
  already is one) and broadcast with ``lax.broadcast_in_dim``.

  Args:
    arr: array-like value to broadcast.
    shape: target shape, either a tuple or a single scalar dimension.

  Returns:
    ``arr`` itself when its shape already equals ``shape``; otherwise an array
    of shape ``shape``.

  Raises:
    ValueError: if ``arr`` cannot be broadcast to ``shape`` (the target has
      fewer dimensions, or a non-1 dimension of ``arr`` differs from the
      corresponding trailing dimension of ``shape``).
  """
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr

  msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
  nlead = len(shape) - len(arr_shape)
  # Reject a target with fewer dimensions *before* zipping: with a negative
  # nlead, shape[nlead:] is shorter than arr_shape, so the two cannot be
  # paired dimension-by-dimension below.
  if nlead < 0:
    raise ValueError(msg.format(arr_shape, shape))
  shape_tail = shape[nlead:]
  compatible = all(
      core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
      for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    raise ValueError(msg.format(arr_shape, shape))

  # Positions where arr's dimension differs from the target; given the check
  # above these are size-1 dims that must grow. They are squeezed out and then
  # re-created (along with the nlead new leading dims) by broadcast_in_dim.
  diff, = np.where(
      tuple(not core.symbolic_equal_dim(arr_d, shape_d)
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
def _broadcast_arrays(*args): """Like Numpy's broadcast_arrays but doesn't return views.""" shapes = [np.shape(arg) for arg in args] if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes): # TODO(mattjj): remove the array(arg) here return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg) for arg in args] result_shape = lax.broadcast_shapes(*shapes) return [_broadcast_to(arg, result_shape) for arg in args]
def _pre_gather_for_multidim_indexing(args: GatherArgs):
  """Checks that a gather represents supported multi-dimensional indexing.

  E.g., jnp.take(op, [[0], [1]], axis=0), i.e. the tf.gather semantics with
  batch_dims=0 along a single axis. Note we currently only support
  multi-dimensional indexing if the last dimension of the start indices is 1.

  The function returns None when the gather matches the supported pattern and
  raises ValueError otherwise; it does not return a boolean.

  Args:
    args: gather arguments; this reads ``op_shape``, ``start_indices_shape``,
      ``slice_sizes`` and ``dnums`` (``start_index_map``,
      ``collapsed_slice_dims`` and ``offset_dims``).

  Raises:
    ValueError: if the dimension numbers, the trailing index dimension,
      ``offset_dims`` or ``slice_sizes`` do not match the supported pattern.
  """
  # Supported pattern, with O = len(op_shape), I = len(start_indices_shape) - 1
  # (index dims excluding the trailing size-1 dim) and `axis` the gathered axis:
  #   slice_sizes          == op_shape[:axis] + (1,) + op_shape[axis+1:]
  #   collapsed_slice_dims == (axis,)
  #   start_index_map      == (axis,)
  #   offset_dims          == (0, ..., axis - 1, axis + I, ..., O + I - 2)
  op_shape = tuple(args.op_shape)
  start_index_map = args.dnums.start_index_map
  collapsed_slice_dims = args.dnums.collapsed_slice_dims
  # Normalize so the equality test below also works if the dimension numbers
  # were built from lists rather than tuples.
  offset_dims = tuple(args.dnums.offset_dims)
  if not (len(op_shape) >= 1 and len(start_index_map) == 1 and
          len(collapsed_slice_dims) == 1 and
          collapsed_slice_dims[0] == start_index_map[0] and
          len(offset_dims) == len(op_shape) - 1):
    raise ValueError("unsupported dimension numbers")
  # start_indices must end with a trailing index-vector dimension of size 1.
  # An empty shape is rejected with the same error instead of an IndexError.
  if (not args.start_indices_shape or
      not core.symbolic_equal_dim(args.start_indices_shape[-1], 1)):
    raise ValueError("start_indices shape[-1] should be 1")
  # Guess the axis
  axis = collapsed_slice_dims[0]
  index_dims = len(args.start_indices_shape) - 1
  expected_offset_dims = tuple(
      list(range(axis)) +
      list(range(axis + index_dims, len(op_shape) + index_dims - 1)))
  if offset_dims != expected_offset_dims:
    raise ValueError("unsupported offset_dims")
  expected_slice_sizes = op_shape[:axis] + (1,) + op_shape[axis + 1:]
  if not core.symbolic_equal_shape(args.slice_sizes, expected_slice_sizes):
    raise ValueError("unsupported slice_sizes")
def join(self, other):
  """Joins ``self`` with ``other`` by returning ``self`` unchanged.

  Both operands are asserted to have symbolically equal shapes and equal
  dtypes; the dtype comparison is only evaluated once the shapes agree.
  """
  assert (core.symbolic_equal_shape(self.shape, other.shape)
          and self.dtype == other.dtype)
  return self