예제 #1
0
    def test_setitem(self):
        """Checks gradients taken through item assignment on arrays.

        Two cases are covered: assignment through a single integer index and
        assignment through a tuple index. In both, the overwritten entries of
        the target are expected to get a zero gradient while the remaining
        entries and the assigned values get a gradient of one.
        """
        # Case 1: single integer index.
        target = array_ops.array([1., 2., 3.])
        lhs = array_ops.array(5.)
        rhs = array_ops.array(10.)

        # Capture the underlying tensors *before* the assignment so the tape
        # watches the pre-assignment values.
        watched = [target.data, lhs.data, rhs.data]
        with tf.GradientTape() as tape:
            tape.watch(watched)
            target[1] = lhs + rhs
            total = array_ops.sum(target)

        grad_target, grad_lhs, grad_rhs = tape.gradient(total.data, watched)
        self.assertSequenceEqual(
            array_ops.array(grad_target).tolist(), [1., 0., 1.])
        self.assertEqual(array_ops.array(grad_lhs).tolist(), 1.)
        self.assertEqual(array_ops.array(grad_rhs).tolist(), 1.)

        # Case 2: tuple index into a 2x2x2 array.
        target = array_ops.array([
            [[1., 2.], [3., 4.]],
            [[5., 6.], [7., 8.]],
        ])
        row = array_ops.array([10., 11.])

        watched = [target.data, row.data]
        with tf.GradientTape() as tape:
            tape.watch(watched)
            target[(1, 0)] = row
            total = array_ops.sum(target)

        grad_target, grad_row = tape.gradient(total.data, watched)
        expected_target_grad = [
            [[1., 1.], [1., 1.]],
            [[0., 0.], [1., 1.]],
        ]
        self.assertSequenceEqual(
            array_ops.array(grad_target).tolist(), expected_target_grad)
        self.assertEqual(array_ops.array(grad_row).tolist(), [1., 1.])
예제 #2
0
파일: math_ops.py 프로젝트: zhaoqiuye/trax
def nanmean(a, axis=None, dtype=None, keepdims=None):
    """Mean of `a` computed as `nansum` divided by the count of non-NaN entries.

    Boolean and integer inputs are delegated to `array_ops.mean` with the
    same `axis`, `dtype` and `keepdims`.

    Args:
      a: input, converted with `array_ops.array`.
      axis: forwarded to the underlying reductions.
      dtype: result dtype; defaults to the dtype of the converted input.
      keepdims: forwarded to the underlying reductions.

    Returns:
      The quotient of `nansum(a, ...)` and the number of non-NaN entries.
    """
    a = array_ops.array(a)

    is_bool_or_int = any(
        np.issubdtype(a.dtype, kind) for kind in (np.bool_, np.integer))
    if is_bool_or_int:
        return array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)

    if dtype is None:
        dtype = a.dtype

    # Summing the "is not NaN" mask yields how many entries contribute.
    valid = logical_not(isnan(a))
    reduce_kwargs = dict(axis=axis, dtype=dtype, keepdims=keepdims)
    count = array_ops.sum(valid, **reduce_kwargs)
    total = nansum(a, **reduce_kwargs)
    return total / count
예제 #3
0
파일: math_ops.py 프로젝트: zhaoqiuye/trax
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
    """Sum the entries along a diagonal of `a`.

    Args:
      a: input, converted with `array_ops.asarray`.
      offset: diagonal offset, forwarded to `array_ops.diagonal`.
      axis1: first axis of the 2-D planes, forwarded to `array_ops.diagonal`.
      axis2: second axis of the 2-D planes, forwarded to
        `array_ops.diagonal`.
      dtype: optional dtype. When given it is normalized with
        `utils.result_type` and used both for converting `a` and for the
        final sum.

    Returns:
      When `offset` is 0, the rank is statically known and the two axes are
      the last two, the result of `tf.linalg.trace` wrapped by
      `utils.tensor_to_ndarray`. Otherwise, the sum over the last axis of
      `array_ops.diagonal(a, offset, axis1, axis2)`.
    """
    # Compare against None explicitly: NumPy dtype instances without fields
    # have length 0 and are therefore falsy, so a plain truthiness test would
    # skip normalization for an explicitly passed `np.dtype(...)`.
    if dtype is not None:
        dtype = utils.result_type(dtype)
    a = array_ops.asarray(a, dtype).data

    if offset == 0:
        a_shape = a.shape
        if a_shape.rank is not None:
            rank = len(a_shape)
            # Fast path: the trace is over the two innermost axes, which is
            # what `tf.linalg.trace` is called for here.
            if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1
                                                       or axis2 == rank - 1):
                return utils.tensor_to_ndarray(tf.linalg.trace(a))

    # General path: extract the requested diagonal, then reduce its last axis.
    a = array_ops.diagonal(a, offset, axis1, axis2)
    return array_ops.sum(a, -1, dtype)