def _broadcast_to(arr, shape):
  """Broadcast ``arr`` to ``shape``.

  If ``arr`` has a ``broadcast_to`` attribute, the call is delegated to it.
  Otherwise ``arr`` is validated with ``_check_arraylike`` and converted with
  ``_asarray`` unless it is already an ``ndarray``. It is then broadcast:
  ``shape`` may add leading dimensions, and each existing dimension of
  ``arr`` must equal either 1 or the matching trailing entry of ``shape``.

  Args:
    arr: array-like input.
    shape: target shape; a non-tuple scalar is treated as a 1-tuple.

  Returns:
    ``arr`` itself when its shape already equals ``shape``; otherwise the
    result of ``lax.broadcast_in_dim``.

  Raises:
    ValueError: if ``shape`` has fewer dimensions than ``arr``, or if a
      dimension of ``arr`` is neither 1 nor the requested size.
  """
  if hasattr(arr, "broadcast_to"):
    return arr.broadcast_to(shape)
  _check_arraylike("broadcast_to", arr)
  arr = arr if isinstance(arr, ndarray) else _asarray(arr)
  if not isinstance(shape, tuple) and np.ndim(shape) == 0:
    shape = (shape,)
  shape = core.canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = np.shape(arr)
  if core.symbolic_equal_shape(arr_shape, shape):
    return arr
  msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
  nlead = len(shape) - len(arr_shape)
  # Reject a lower target rank before zipping. With nlead < 0, arr_shape is
  # longer than shape_tail, so the safe_zip calls below would receive
  # sequences of different lengths and fail before this error was reached.
  if nlead < 0:
    raise ValueError(msg.format(arr_shape, shape))
  shape_tail = shape[nlead:]
  # Align arr's dims with the trailing dims of the target; each must be 1 or
  # equal to the requested size.
  compatible = all(
      core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
      for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
  if not compatible:
    raise ValueError(msg.format(arr_shape, shape))
  # Axes of arr whose size differs from the target. By the check above these
  # must be size-1 axes; they are squeezed out and re-created by broadcasting.
  diff, = np.where(
      tuple(not core.symbolic_equal_dim(arr_d, shape_d)
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
  new_dims = tuple(range(nlead)) + tuple(nlead + diff)
  # Output axes fed by the axes of arr that survive the squeeze.
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims)
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  """Batching rule for ``approx_top_k_p``.

  The batch dimension of the first batched operand becomes the common batch
  axis. Every other batched operand is moved to that axis, and every
  unbatched operand is broadcast along it. ``reduction_dimension`` is then
  shifted past the batch axis, and the primitive is re-bound with all other
  parameters passed through unchanged.

  Returns:
    A pair ``(outputs, bdims)``, where ``bdims`` repeats the common batch axis
    once per operand.
  """
  # The first operand that carries a batch dimension fixes the layout.
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
  new_args = []
  for arg, bdim in zip(batched_args, batch_dims):
    if bdim is None:
      # Unbatched operand: broadcast it along the new batch axis.
      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
    else:
      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
  # Resolve a negative axis against the *unbatched* rank before shifting.
  # Otherwise an axis such as -1 would refer to the batch axis whenever the
  # batch axis is placed after the reduction axis.
  if reduction_dimension < 0:
    reduction_dimension += prototype_arg.ndim - 1
  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
  # NOTE(review): one bdim per operand; confirm this matches the number of
  # outputs approx_top_k_p produces.
  bdims = (new_bdim,) * len(new_args)
  return (approx_top_k_p.bind(
      *new_args,
      k=k,
      reduction_dimension=new_reduction_dim,
      recall_target=recall_target,
      # Previously hard-coded to False, which discarded the caller's value.
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), bdims)
def _bcast_select_n(pred, *cases):
  """Call ``lax.select_n``, first broadcasting ``pred`` to the case shape.

  When ``pred`` has a different rank from ``cases[0]``, its dimensions are
  mapped onto the leading dimensions of ``cases[0]``'s shape with
  ``lax.broadcast_in_dim``. Otherwise ``pred`` is passed through as-is.
  """
  pred_rank = np.ndim(pred)
  template = cases[0]
  if pred_rank == np.ndim(template):
    return lax.select_n(pred, *cases)
  leading_dims = list(range(pred_rank))
  widened = lax.broadcast_in_dim(pred, np.shape(template), leading_dims)
  return lax.select_n(widened, *cases)
def _bcast_select(pred, on_true, on_false):
  """Call ``lax.select``, first broadcasting ``pred`` to ``on_true``'s shape.

  When ``pred`` has a different rank from ``on_true``, its dimensions are
  mapped onto the leading dimensions of ``on_true``'s shape with
  ``lax.broadcast_in_dim``. Otherwise ``pred`` is passed through as-is.
  """
  pred_rank = np.ndim(pred)
  if pred_rank == np.ndim(on_true):
    return lax.select(pred, on_true, on_false)
  leading_dims = list(range(pred_rank))
  widened = lax.broadcast_in_dim(pred, np.shape(on_true), leading_dims)
  return lax.select(widened, on_true, on_false)