Example #1
0
def _broadcast_to(arr, shape):
    """Broadcast ``arr`` to ``shape`` via squeeze + ``lax.broadcast_in_dim``.

    Objects defining their own ``broadcast_to`` method handle the broadcast
    themselves; otherwise the input is coerced to an array, the shapes are
    checked for compatibility, and size-1 dimensions are squeezed away before
    a single ``broadcast_in_dim`` produces the result.
    """
    # Defer to types that know how to broadcast themselves.
    if hasattr(arr, "broadcast_to"):
        return arr.broadcast_to(shape)
    _check_arraylike("broadcast_to", arr)
    if not isinstance(arr, ndarray):
        arr = _asarray(arr)
    # Accept a bare scalar dimension as a 1-tuple.
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
        shape = (shape,)
    shape = core.canonicalize_shape(shape)  # check that shape is concrete
    arr_shape = np.shape(arr)
    if core.symbolic_equal_shape(arr_shape, shape):
        return arr
    nlead = len(shape) - len(arr_shape)
    shape_tail = shape[nlead:]
    # Every trailing dimension must either match or be broadcastable from 1.
    ok = all(core.symbolic_equal_one_of_dim(d, [1, t])
             for d, t in safe_zip(arr_shape, shape_tail))
    if nlead < 0 or not ok:
        msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
        raise ValueError(msg.format(arr_shape, shape))
    # Indices of trailing dims that actually change size (the size-1 ones).
    mismatched = tuple(not core.symbolic_equal_dim(d, t)
                       for d, t in safe_zip(arr_shape, shape_tail))
    diff, = np.where(mismatched)
    new_dims = tuple(range(nlead)) + tuple(nlead + diff)
    kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
    return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape,
                                kept_dims)
Example #2
0
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override, aggregate_to_topk):
  """Batching rule for the ``approx_top_k`` primitive.

  Moves every batched argument's batch axis to a common position (taken from
  the first batched argument), broadcasts any unbatched arguments to match,
  shifts the reduction axis to account for the inserted batch axis, and
  rebinds the primitive with all static parameters forwarded unchanged.

  Returns the bound result together with the tuple of output batch dims.
  """
  # Use the first batched argument as the template for shape and batch axis.
  prototype_arg, new_bdim = next(
      (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
  new_args = []
  for arg, bdim in zip(batched_args, batch_dims):
    if bdim is None:
      # Unbatched arg: broadcast it into the prototype's batched shape,
      # mapping its dims onto every axis except the batch axis.
      dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
      new_args.append(lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
    else:
      new_args.append(batching.moveaxis(arg, bdim, new_bdim))
  # The reduction axis shifts right by one if the batch axis precedes it.
  new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
  bdims = (new_bdim,) * len(new_args)
  return (approx_top_k_p.bind(
      *new_args,
      k=k,
      reduction_dimension=new_reduction_dim,
      recall_target=recall_target,
      # BUG FIX: was hard-coded to False, which silently turned a batched
      # approx_max_k into approx_min_k. Forward the caller's setting.
      is_max_k=is_max_k,
      reduction_input_size_override=reduction_input_size_override,
      aggregate_to_topk=aggregate_to_topk), bdims)
Example #3
0
def _bcast_select_n(pred, *cases):
    """Apply ``lax.select_n`` after broadcasting ``pred`` up to the cases' rank.

    When ``pred`` has fewer dimensions than the first case, its dims are
    mapped onto the leading axes of the case's shape before selecting.
    """
    pred_rank = np.ndim(pred)
    if pred_rank != np.ndim(cases[0]):
        # Place pred's existing dims at the front of the target shape.
        leading_axes = list(range(pred_rank))
        pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), leading_axes)
    return lax.select_n(pred, *cases)
Example #4
0
def _bcast_select(pred, on_true, on_false):
    """Apply ``lax.select`` after broadcasting ``pred`` up to ``on_true``'s rank.

    When ``pred`` has fewer dimensions than ``on_true``, its dims are mapped
    onto the leading axes of ``on_true``'s shape before selecting.
    """
    pred_rank = np.ndim(pred)
    if pred_rank != np.ndim(on_true):
        # Place pred's existing dims at the front of the target shape.
        leading_axes = list(range(pred_rank))
        pred = lax.broadcast_in_dim(pred, np.shape(on_true), leading_axes)
    return lax.select(pred, on_true, on_false)