def _setitem(arr, index, value):
    """Sets the `value` at `index` in the array `arr`.

    This works by replacing the slice at `index` in the tensor with `value`.
    Since tensors are immutable, this builds a new tensor using the `tf.concat`
    op. Currently, only 0-d and 1-d indices are supported.

    A negative entry in `index` counts from the end of the corresponding axis
    (so `-1` addresses the last element).

    Note that this may break gradients e.g.

    a = tf_np.array([1, 2, 3])
    old_a_t = a.data

    with tf.GradientTape(persistent=True) as g:
      g.watch(a.data)
      b = a * 2
      a[0] = 5
    g.gradient(b.data, [a.data])  # [None]
    g.gradient(b.data, [old_a_t])  # [[2., 2., 2.]]

    Here `d_b / d_a` is `[None]` since a.data no longer points to the same
    tensor.

    Args:
      arr: array_like.
      index: scalar or 1-d integer array.
      value: value to set at index.

    Returns:
      The ndarray obtained from `array_creation.asarray(arr)` with its
      underlying tensor replaced by the updated one. NOTE(review): the
      in-place effect on the caller's object assumes `asarray` hands back the
      same object when `arr` is already an ndarray -- confirm.

    Raises:
      ValueError: if `index` is not a scalar or 1-d array.
    """
    # TODO(srbs): Figure out a solution to the gradient problem.
    arr = array_creation.asarray(arr)
    index = array_creation.asarray(index)
    if index.ndim == 0:
        index = ravel(index)
    elif index.ndim > 1:
        raise ValueError('index must be a scalar or a 1-d array.')

    # The first `len(index)` axes are consumed by the index, so `value` has to
    # fill the remaining trailing shape.
    value = array_creation.asarray(value, dtype=arr.dtype)
    target_shape = arr.shape[len(index):]
    if target_shape != value.shape:
        value = array_manipulation.broadcast_to(value, target_shape)

    head = index.data[0]
    # Normalise a negative leading index before slicing. Without this, a head
    # of -1 makes `head + 1` equal to 0, so `arr.data[head + 1:]` would select
    # the whole tensor rather than an empty tail and the result would have the
    # wrong length.
    axis_len = tf.cast(tf.shape(arr.data)[0], head.dtype)
    head = tf.where(head < 0, head + axis_len, head)

    prefix_t = arr.data[:head]
    postfix_t = arr.data[head + 1:]

    if len(index) == 1:
        replacement_t = value.data
    else:
        # Recurse into the selected sub-array with the remaining index
        # components, then splice the updated sub-array back in.
        replacement_t = _setitem(arr[head], index[1:], value).data

    arr._data = tf.concat(  # pylint: disable=protected-access
        [prefix_t, tf.expand_dims(replacement_t, 0), postfix_t], 0)
    return arr
def run_test(arr, shape):
    """Compares `array_manipulation.broadcast_to` with `np.broadcast_to`.

    Each callable in `self.array_transforms` is applied to `arr`, and the
    result is broadcast to `shape` by both implementations; the two outputs
    are handed to `self.match`.

    `self` is not a parameter: it is resolved from the enclosing scope.

    Args:
      arr: input passed to every transform.
      shape: target shape forwarded to both broadcast functions.
    """
    for transform in self.array_transforms:
        transformed = transform(arr)
        # Evaluate the implementation under test first, then the NumPy
        # reference, keeping the original call order.
        actual = array_manipulation.broadcast_to(transformed, shape)
        expected = np.broadcast_to(transformed, shape)
        self.match(actual, expected)
def replicate(x, num_devices=2):
    """Broadcasts `x` to a shape with an extra leading axis.

    Args:
      x: object exposing a `shape` that can be appended to a tuple with `+`.
      num_devices: size of the new leading axis; defaults to 2.

    Returns:
      The result of `array_manipulation.broadcast_to` for the target shape
      `(num_devices,) + x.shape`.
    """
    leading_axis = (num_devices,)
    replicated_shape = leading_axis + x.shape
    return array_manipulation.broadcast_to(x, replicated_shape)