Exemple #1
0
def broadcast_to(a, shape):
    """Broadcasts `a` to `shape`, if the two are compatible.

    The work is delegated to `array_creation.full`, with `a` passed as the
    fill value for an array of the requested shape.

    Args:
      a: array_like. The value(s) to broadcast.
      shape: a scalar or 1-d tuple/list giving the target shape.

    Returns:
      An ndarray: whatever `array_creation.full(shape, a)` produces.
    """
    broadcasted = array_creation.full(shape, a)
    return broadcasted
Exemple #2
0
def _setitem(arr, index, value):
  """Writes `value` into `arr` at position `index`.

  Tensors are immutable, so nothing is modified in place at the tensor level.
  Instead a new tensor is assembled with `tf.concat` from three pieces: the
  rows before `index[0]`, the replacement row, and the rows after `index[0]`.
  The array's `_data` attribute is then rebound to that new tensor. Only 0-d
  and 1-d indices are supported.

  Because `_data` is rebound, gradients recorded against the old tensor are
  not reachable through the array any more, e.g.

  a = tf_np.array([1, 2, 3])
  old_a_t = a.data

  with tf.GradientTape(persistent=True) as g:
    g.watch(a.data)
    b = a * 2
    a[0] = 5
  g.gradient(b.data, [a.data])  # [None]
  g.gradient(b.data, [old_a_t])  # [[2., 2., 2.]]

  `d_b / d_a` comes out as `[None]` because `a.data` no longer refers to the
  tensor that was being watched.

  Args:
    arr: array_like. Converted with `array_creation.asarray`; the converted
      object is the one whose `_data` is rebound (presumably the same object
      when `arr` is already an ndarray -- confirm against `asarray`).
    index: scalar or 1-d integer array. A 1-d index with more than one entry
      addresses nested dimensions: entry `k` selects along axis `k`.
    value: value to set at index. It is converted to `arr`'s dtype and, when
      its shape differs from the shape of the addressed slice, expanded to
      that shape with `array_creation.full`.

  Returns:
    None. The result is stored by rebinding `_data` on the array.

  Raises:
    ValueError: if `index` is not a scalar or 1-d array.
  """
  # TODO(srbs): Figure out a solution to the gradient problem.
  arr = array_creation.asarray(arr)
  index = array_creation.asarray(index)
  if index.ndim > 1:
    raise ValueError('index must be a scalar or a 1-d array.')
  if index.ndim == 0:
    # Promote a scalar index to a length-1 index so both cases share a path.
    index = ravel(index)

  num_indices = len(index)
  # Shape of the slice addressed by `index`: one leading axis per index entry
  # is consumed.
  target_shape = arr.shape[num_indices:]
  value = array_creation.asarray(value, dtype=arr.dtype)
  if target_shape != value.shape:
    value = array_creation.full(target_shape, value)

  position = index.data[0]
  head_t = arr.data[:position]
  tail_t = arr.data[position + 1:]

  if num_indices == 1:
    replacement_t = value.data
  else:
    # Update the selected sub-array with the remaining index entries, then
    # splice the updated sub-array back in at `position`.
    subarray = arr[position]
    _setitem(subarray, index[1:], value)
    replacement_t = subarray.data

  arr._data = tf.concat(  # pylint: disable=protected-access
      [head_t, tf.expand_dims(replacement_t, 0), tail_t], 0)
 def testFull(self):
     """Checks `array_creation.full` against `np.full`.

     Every (fill value, shape) case is run through each combination of
     `self.array_transforms` (applied to the fill value) and
     `self.shape_transforms` (applied to the shape). Each combination is
     compared once without a dtype and once per entry of `self.all_types`,
     using `self.match`.
     """
     # Each case is a (fill value, target shape) pair.
     cases = [
         (5, ()),
         (5, (7, )),
         (5., (7, )),
         ([5, 8], (2, )),
         ([5, 8], (3, 2)),
         ([[5], [8]], (2, 3)),
         ([[5], [8]], (3, 2, 5)),
         ([[5.], [8.]], (3, 2, 5)),
         ([[3, 4], [5, 6], [7, 8]], (3, 3, 2)),
     ]
     for raw_fill, raw_shape in cases:
         transform_pairs = itertools.product(self.array_transforms,
                                             self.shape_transforms)
         for fill_transform, shape_transform in transform_pairs:
             fill_value = fill_transform(raw_fill)
             shape = shape_transform(raw_shape)

             # Default dtype: let both libraries infer it from the fill.
             actual = array_creation.full(shape, fill_value)
             expected = np.full(shape, fill_value)
             self.match(actual, expected)

             # Explicit dtype: one comparison per supported type.
             for dtype in self.all_types:
                 actual = array_creation.full(shape, fill_value, dtype=dtype)
                 expected = np.full(shape, fill_value, dtype=dtype)
                 self.match(actual, expected)