コード例 #1
0
ファイル: array_methods.py プロジェクト: yasyaindra/trax
def _setitem(arr, index, value):
    """Sets the `value` at `index` in the array `arr`.

  This works by replacing the slice at `index` in the tensor with `value`.
  Since tensors are immutable, this builds a new tensor using the `tf.concat`
  op. Currently, only 0-d and 1-d indices are supported.

  Note that this may break gradients e.g.

  a = tf_np.array([1, 2, 3])
  old_a_t = a.data

  with tf.GradientTape(persistent=True) as g:
    g.watch(a.data)
    b = a * 2
    a[0] = 5
  g.gradient(b.data, [a.data])  # [None]
  g.gradient(b.data, [old_a_t])  # [[2., 2., 2.]]

  Here `d_b / d_a` is `[None]` since a.data no longer points to the same
  tensor.

  Args:
    arr: array_like.
    index: scalar or 1-d integer array.
    value: value to set at index.

  Returns:
    ndarray

  Raises:
    ValueError: if `index` is not a scalar or 1-d array.
  """
    # TODO(srbs): Figure out a solution to the gradient problem.
    arr = array_creation.asarray(arr)
    index = array_creation.asarray(index)
    if index.ndim == 0:
        # Normalize a scalar index to a length-1 1-d index so the code below
        # can treat the 0-d and 1-d cases uniformly.
        index = ravel(index)
    elif index.ndim > 1:
        raise ValueError('index must be a scalar or a 1-d array.')
    value = array_creation.asarray(value, dtype=arr.dtype)
    if arr.shape[len(index):] != value.shape:
        # `value` must match the shape of the slice it replaces; broadcast it
        # up to that shape (raises if the shapes are incompatible).
        value = array_manipulation.broadcast_to(value, arr.shape[len(index):])
    # Slices of the underlying tensor before and after the target position
    # along axis 0; the new tensor is rebuilt as prefix + value + postfix.
    prefix_t = arr.data[:index.data[0]]
    postfix_t = arr.data[index.data[0] + 1:]
    if len(index) == 1:
        # Base case: splice `value` in at the (single-axis) index. The
        # expand_dims restores the leading axis removed by indexing.
        arr._data = tf.concat(  # pylint: disable=protected-access
            [prefix_t, tf.expand_dims(value.data, 0), postfix_t], 0)
    else:
        # Recursive case: update the sub-array at index[0] with the remaining
        # index components, then splice the updated sub-array back into `arr`.
        # NOTE: the recursive call mutates `subarray._data` in place.
        subarray = arr[index.data[0]]
        _setitem(subarray, index[1:], value)
        arr._data = tf.concat(  # pylint: disable=protected-access
            [prefix_t, tf.expand_dims(subarray.data, 0), postfix_t], 0)
コード例 #2
0
 def run_test(arr, shape):
     """Checks broadcast_to against NumPy for every registered transform of `arr`."""
     for transform in self.array_transforms:
         candidate = transform(arr)
         expected = np.broadcast_to(candidate, shape)
         actual = array_manipulation.broadcast_to(candidate, shape)
         self.match(actual, expected)
コード例 #3
0
 def replicate(x, num_devices=2):
   """Returns `x` broadcast to a new leading axis of length `num_devices`."""
   replicated_shape = (num_devices,) + x.shape
   return array_manipulation.broadcast_to(x, replicated_shape)