Code example #1
def _broadcast_to(arr, shape):
    if hasattr(arr, "broadcast_to"):
        return arr.broadcast_to(shape)
    _check_arraylike("broadcast_to", arr)
    arr = arr if isinstance(arr, ndarray) else _asarray(arr)
    if not isinstance(shape, tuple) and np.ndim(shape) == 0:
        shape = (shape, )
    shape = core.canonicalize_shape(shape)  # check that shape is concrete
    arr_shape = np.shape(arr)
    if core.symbolic_equal_shape(arr_shape, shape):
        return arr
    else:
        nlead = len(shape) - len(arr_shape)
        shape_tail = shape[nlead:]
        compatible = all(
            core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
            for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
        if nlead < 0 or not compatible:
            msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
            raise ValueError(msg.format(arr_shape, shape))
        diff, = np.where(
            tuple(not core.symbolic_equal_dim(arr_d, shape_d)
                  for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
        new_dims = tuple(range(nlead)) + tuple(nlead + diff)
        kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
        return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape,
                                    kept_dims)
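
This internal helper appears to be the implementation behind the public jnp.broadcast_to entry point; a minimal usage sketch of that public API (assuming a standard JAX install):

import jax.numpy as jnp

x = jnp.arange(3)                # shape (3,)
y = jnp.broadcast_to(x, (2, 3))  # add a leading dimension of size 2
print(y.shape)                   # (2, 3)
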
Code example #2
File: util.py  Project: xueeinstein/jax
def _broadcast_arrays(*args):
  """Like Numpy's broadcast_arrays but doesn't return views."""
  shapes = [np.shape(arg) for arg in args]
  if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
    # TODO(mattjj): remove the array(arg) here
    return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
            for arg in args]
  result_shape = lax.broadcast_shapes(*shapes)
  return [_broadcast_to(arg, result_shape) for arg in args]
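
A minimal usage sketch of the public jnp.broadcast_arrays entry point this helper appears to serve; both inputs are broadcast to a common result shape (no views, per the docstring):

import jax.numpy as jnp

a = jnp.ones((3, 1))
b = jnp.arange(4)
outs = jnp.broadcast_arrays(a, b)   # both broadcast to shape (3, 4)
print([o.shape for o in outs])      # [(3, 4), (3, 4)]
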
Code example #3
File: impl_no_xla.py  Project: matthewfeickert/jax
def _pre_gather_for_multidim_indexing(args: GatherArgs):
    """Returns True if this call to gather represents multi-dimensional indexing.

  E.g., jnp.take(op, [[0], [1]], axis=0).
  Note we currently only support multi-dimensional indexing if the last
  dimension is 1.
  """
    # Handle only the case when tf.gather argument batch_dims=0.
    # Find axis to match the tf.gather semantics
    # Let I = len(start_indices_shape)
    # Let O = len(op_shape)
    # slice_sizes == op_shape[:axis] + (1,) + op_shape[axis+1:]
    # collapsed_slice_dims == (axis,)
    # start_index_map == (axis,)
    # offset_dims == (0, 1, ..., axis - 1, axis + I, ..., O + I - 1)
    # We added a trailing dimension of size 1
    op_shape = args.op_shape
    start_index_map = args.dnums.start_index_map
    collapsed_slice_dims = args.dnums.collapsed_slice_dims
    offset_dims = args.dnums.offset_dims
    if not (len(op_shape) >= 1 and len(start_index_map) == 1
            and len(collapsed_slice_dims) == 1 and collapsed_slice_dims[0]
            == start_index_map[0] and len(offset_dims) == len(op_shape) - 1):
        raise ValueError("unsupported dimension numbers")
    # We added a trailing dimension of size 1
    if not core.symbolic_equal_dim(args.start_indices_shape[-1], 1):
        raise ValueError("start_indices shape[-1] should be 1")
    # Guess the axis
    axis = collapsed_slice_dims[0]
    index_dims = len(args.start_indices_shape) - 1
    expected_offset_dims = tuple(
        list(range(axis)) +
        list(range(axis + index_dims,
                   len(op_shape) + index_dims - 1)))
    if offset_dims != expected_offset_dims:
        raise ValueError("unsupported offset_dims")
    expected_slice_sizes = op_shape[:axis] + (
        1, ) + op_shape[axis + 1:]  # type: ignore
    if not core.symbolic_equal_shape(args.slice_sizes, expected_slice_sizes):
        raise ValueError("unsupported slice_sizes")
Code example #4
  def join(self, other):
    assert core.symbolic_equal_shape(self.shape, other.shape)
    assert self.dtype == other.dtype
    return self
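
A self-contained sketch (not JAX's actual classes) of the join contract above: joining two abstract values with matching shape and dtype is the identity, and any mismatch trips the assertions. Plain tuple equality stands in for core.symbolic_equal_shape:

import numpy as np

class FakeAval:
  def __init__(self, shape, dtype):
    self.shape, self.dtype = shape, dtype

  def join(self, other):
    # Same checks as above, with tuple equality in place of symbolic shape equality.
    assert self.shape == other.shape
    assert self.dtype == other.dtype
    return self

a = FakeAval((2, 3), np.float32)
b = FakeAval((2, 3), np.float32)
print(a.join(b).shape)   # (2, 3)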