Example #1
  def _assert_self_adjoint(self):
    # Check that the diagonal has zero imaginary part and that the
    # superdiagonal and subdiagonal are complex conjugates of each other.

    asserts = []
    diag_message = (
        'This tridiagonal operator contained non-zero '
        'imaginary values on the diagonal.')
    off_diag_message = (
        'This tridiagonal operator has non-conjugate '
        'subdiagonal and superdiagonal.')

    if self.diagonals_format == _MATRIX:
      asserts += [check_ops.assert_equal(
          self.diagonals, linalg.adjoint(self.diagonals),
          message='Matrix was not equal to its adjoint.')]
    elif self.diagonals_format == _COMPACT:
      diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
      asserts += [linear_operator_util.assert_zero_imag_part(
          diagonals[..., 1, :], message=diag_message)]
      # Roll the subdiagonal so its padding element moves to the end.
      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          diagonals[..., 0, :-1],
          message=off_diag_message)]
    else:
      asserts += [linear_operator_util.assert_zero_imag_part(
          self.diagonals[1], message=diag_message)]
      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          self.diagonals[0][..., :-1],
          message=off_diag_message)]
    return control_flow_ops.group(asserts)
Example #2
 def testRollShiftAndAxisMustBeSameSizeRaises(self):
   tensor = [[1, 2], [3, 4]]
   shift = [1]
   axis = [0, 1]
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "shift and axis must have the same size"):
       manip_ops.roll(tensor, shift, axis).eval()
Example #3
 def testRollInputMustVectorHigherRaises(self):
   tensor = 7
   shift = 1
   axis = 0
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "input must be 1-D or higher"):
       manip_ops.roll(tensor, shift, axis).eval()
Example #4
 def testRollAxisMustBeScalarOrVectorRaises(self):
   tensor = [[1, 2], [3, 4]]
   shift = 1
   axis = [[0, 1]]
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "axis must be a scalar or a 1-D vector"):
       manip_ops.roll(tensor, shift, axis).eval()
Example #5
 def testRollAxisOutOfRangeRaises(self):
     tensor = [1, 2]
     shift = 1
     axis = 1
     with self.cached_session(use_gpu=True):
         with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                     "is out of range"):
             manip_ops.roll(tensor, shift, axis).eval()
Example #6
 def testRollAxisOutOfRangeRaises(self):
   tensor = [1, 2]
   shift = 1
   axis = 1
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "is out of range"):
       manip_ops.roll(tensor, shift, axis).eval()
Example #7
 def testNegativeAxis(self):
   self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
   self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
    # A negative axis is valid only if 0 <= axis + dims < dims.
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "is out of range"):
       manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
                      3, -10).eval()
Example #8
 def testRollShiftMustBeScalarOrVectorRaises(self):
   # The shift should be a scalar or 1-D, checked in kernel.
   tensor = [[1, 2], [3, 4]]
   shift = array_ops.placeholder(dtype=dtypes.int32)
   axis = 1
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "shift must be a scalar or a 1-D vector"):
       manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
Example #9
 def testRollInputMustVectorHigherRaises(self):
   # The input should be 1-D or higher, checked in kernel.
   tensor = array_ops.placeholder(dtype=dtypes.int32)
   shift = 1
   axis = 0
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "input must be 1-D or higher"):
       manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
Example #10
 def testRollInputMustVectorHigherRaises(self):
     # The input should be 1-D or higher, checked in kernel.
     tensor = array_ops.placeholder(dtype=dtypes.int32)
     shift = 1
     axis = 0
     with self.cached_session(use_gpu=True):
         with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                     "input must be 1-D or higher"):
             manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
Example #11
 def testNegativeAxis(self):
   self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
   self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
    # A negative axis is valid only if 0 <= axis + dims < dims.
   with self.cached_session(use_gpu=True):
     with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                 "is out of range"):
       manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
                      3, -10).eval()
Example #12
 def testRollShiftAndAxisMustBeSameSizeRaises(self):
    # The shift and axis must be the same size, checked in the kernel.
   tensor = [[1, 2], [3, 4]]
   shift = array_ops.placeholder(dtype=dtypes.int32)
   axis = [0, 1]
   with self.cached_session(use_gpu=True):
     with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                 "shift and axis must have the same size"):
       manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
Example #13
 def testRollShiftMustBeScalarOrVectorRaises(self):
   # The shift should be a scalar or 1-D, checked in kernel.
   tensor = [[1, 2], [3, 4]]
   shift = array_ops.placeholder(dtype=dtypes.int32)
   axis = 1
   with self.cached_session(use_gpu=True):
     with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                 "shift must be a scalar or a 1-D vector"):
       manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
Example #14
 def testRollShiftAndAxisMustBeSameSizeRaises(self):
    # The shift and axis must be the same size, checked in the kernel.
   tensor = [[1, 2], [3, 4]]
   shift = array_ops.placeholder(dtype=dtypes.int32)
   axis = [0, 1]
   with self.test_session():
     with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                  "shift and axis must have the same size"):
       manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
Example #15
 def testRollAxisMustBeScalarOrVectorRaises(self):
     # The axis should be a scalar or 1-D, checked in kernel.
     tensor = [[1, 2], [3, 4]]
     shift = 1
     axis = array_ops.placeholder(dtype=dtypes.int32)
     with self.cached_session():
         with self.assertRaisesRegexp(
                 errors_impl.InvalidArgumentError,
                 "axis must be a scalar or a 1-D vector"):
             manip_ops.roll(tensor, shift,
                            axis).eval(feed_dict={axis: [[0, 1]]})
Example #16
def fftshift(x, axes=None, name=None):
    """Shift the zero-frequency component to the center of the spectrum.
    This function swaps half-spaces for all axes listed (defaults to all).
    Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.
    @compatibility(numpy)
    Equivalent to numpy.fft.fftshift.
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
    @end_compatibility
    For example:
    ```python
    x = tf.signal.fftshift([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.])
    x.numpy() # array([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
    ```
    Args:
      x: `Tensor`, input tensor.
      axes: `int` or shape `tuple`, optional. Axes over which to shift.
        Defaults to None, which shifts all axes.
      name: An optional name for the operation.

    Returns:
      A `Tensor`, the shifted tensor.
    """
    with _ops.name_scope(name, "fftshift") as name:
        x = _ops.convert_to_tensor(x)
    if axes is None:
        axes = tuple(range(x.shape.ndims))
        shift = [int(dim // 2) for dim in x.shape]
    elif isinstance(axes, int):
        shift = int(x.shape[axes] // 2)
    else:
        shift = [int((x.shape[ax]) // 2) for ax in axes]

    return manip_ops.roll(x, shift, axes)
Example #17
def ifftshift(x, axes=None, name=None):
    """The inverse of fftshift.
    Although identical for even-length x,
    the functions differ by one sample for odd-length x.
    @compatibility(numpy)
    Equivalent to numpy.fft.ifftshift.
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
    @end_compatibility
    For example:
    ```python
    x = tf.signal.ifftshift([[ 0.,  1.,  2.],[ 3.,  4., -4.],[-3., -2., -1.]])
    x.numpy() # array([[ 4., -4.,  3.],[-2., -1., -3.],[ 1.,  2.,  0.]])
    ```
    Args:
      x: `Tensor`, input tensor.
      axes: `int` or shape `tuple`, optional. Axes over which to calculate.
        Defaults to None, which shifts all axes.
      name: An optional name for the operation.

    Returns:
      A `Tensor`, the shifted tensor.
    """
    with _ops.name_scope(name, "ifftshift") as name:
        x = _ops.convert_to_tensor(x)
    if axes is None:
        axes = tuple(range(x.shape.ndims))
        shift = [-int(dim // 2) for dim in x.shape]
    elif isinstance(axes, int):
        shift = -int(x.shape[axes] // 2)
    else:
        shift = [-int(x.shape[ax] // 2) for ax in axes]

    return manip_ops.roll(x, shift, axes)
Example #18
 def _testRoll(self, a, shift, axis):
   with self.session() as session:
     with self.test_scope():
       p = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
       output = manip_ops.roll(a, shift, axis)
     result = session.run(output, {p: a})
     self.assertAllEqual(result, np.roll(a, shift, axis))
Example #19
  def build_operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False,
      diagonals_format='sequence'):
    shape = list(build_info.shape)

    # Ensure that the diagonal has large enough values. If we generate a
    # self-adjoint PD matrix, then the diagonal will be dominant, guaranteeing
    # positive definiteness.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=4., maxval=6., dtype=dtype)
    # We'll truncate these depending on the format
    subdiag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
      # The first element of subdiag is ignored. We'll add a dummy element
      # to superdiag to pad it.
      superdiag = math_ops.conj(subdiag)
      superdiag = manip_ops.roll(superdiag, shift=-1, axis=-1)
    else:
      superdiag = linear_operator_test_util.random_sign_uniform(
          shape[:-1], minval=1., maxval=2., dtype=dtype)

    matrix_diagonals = array_ops.stack(
        [superdiag, diag, subdiag], axis=-2)
    matrix = gen_array_ops.matrix_diag_v3(
        matrix_diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

    if diagonals_format == 'sequence':
      diagonals = [superdiag, diag, subdiag]
    elif diagonals_format == 'compact':
      diagonals = array_ops.stack([superdiag, diag, subdiag], axis=-2)
    elif diagonals_format == 'matrix':
      diagonals = matrix

    lin_op_diagonals = diagonals

    if use_placeholder:
      if diagonals_format == 'sequence':
        lin_op_diagonals = [array_ops.placeholder_with_default(
            d, shape=None) for d in lin_op_diagonals]
      else:
        lin_op_diagonals = array_ops.placeholder_with_default(
            lin_op_diagonals, shape=None)

    operator = linalg_lib.LinearOperatorTridiag(
        diagonals=lin_op_diagonals,
        diagonals_format=diagonals_format,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)
    return operator, matrix
Example #20
 def _testGradient(self, np_input, shift, axis):
   with self.test_session():
     inx = constant_op.constant(np_input.tolist())
     xs = list(np_input.shape)
     y = manip_ops.roll(inx, shift, axis)
      # y's shape is expected to match x's.
     ys = xs
     jacob_t, jacob_n = gradient_checker.compute_gradient(
         inx, xs, y, ys, x_init_value=np_input)
     self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
Example #21
 def _testGradient(self, np_input, shift, axis):
     with self.cached_session(use_gpu=True):
         inx = constant_op.constant(np_input.tolist())
         xs = list(np_input.shape)
         y = manip_ops.roll(inx, shift, axis)
          # y's shape is expected to match x's.
         ys = xs
         jacob_t, jacob_n = gradient_checker.compute_gradient(
             inx, xs, y, ys, x_init_value=np_input)
         self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
Example #22
 def _construct_adjoint_diagonals(self, diagonals):
   # Constructs adjoint tridiagonal matrix from diagonals.
   if self.diagonals_format == _SEQUENCE:
     diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
      # The subdiagonal and the superdiagonal swap places, so we shift their
      # padding elements to the opposite ends.
     diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
     diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
     return diagonals
   elif self.diagonals_format == _MATRIX:
     return linalg.adjoint(diagonals)
   else:
     diagonals = math_ops.conj(diagonals)
     superdiag, diag, subdiag = array_ops.unstack(
         diagonals, num=3, axis=-2)
      # The subdiagonal and the superdiagonal swap places, so we shift both
      # to re-align their padding.
     new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
     new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
     return array_ops.stack([new_superdiag, diag, new_subdiag], axis=-2)
Example #23
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True):
    r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to convert a tensor into the compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or
  `[..., M, K]`. The latter allows solving `K` systems simultaneously with the
  same left-hand sides and `K` different right-hand sides. If `transpose_rhs`
  is set to `True` the expected shape is `[..., M]` or `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider disabling pivoting and using
  `K = 1`, or, alternatively, running the computation on the CPU.

  On CPU, the solution is computed via Gaussian elimination with or without
  partial pivoting, depending on the `partial_pivoting` parameter. On GPU,
  Nvidia's cuSPARSE library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`; see the description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't
      known statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.

  Raises:
    ValueError: If an unsupported type is provided as input, or if the input
      tensors have incorrect shapes.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.

  """
    if diagonals_format == 'compact':
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'sequence':
        if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
            raise ValueError(
                'Expected diagonals to be a sequence of length 3.')

        superdiag, maindiag, subdiag = diagonals
        if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])
                or not superdiag.shape[:-1].is_compatible_with(
                    maindiag.shape[:-1])):
            raise ValueError(
                'Tensors representing the three diagonals must have the same '
                'shape, except for the last dimension, got {}, {}, {}'.format(
                    subdiag.shape, maindiag.shape, superdiag.shape))

        m = tensor_shape.dimension_value(maindiag.shape[-1])

        def pad_if_necessary(t, name, last_dim_padding):
            n = tensor_shape.dimension_value(t.shape[-1])
            if not n or n == m:
                return t
            if n == m - 1:
                paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                            [last_dim_padding])
                return array_ops.pad(t, paddings)
            raise ValueError(
                'Expected {} to have length {} or {}, got {}.'.format(
                    name, m, m - 1, n))

        subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
        superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

        diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    if diagonals_format == 'matrix':
        m1 = tensor_shape.dimension_value(diagonals.shape[-1])
        m2 = tensor_shape.dimension_value(diagonals.shape[-2])
        if m1 and m2 and m1 != m2:
            raise ValueError(
                'Expected last two dimensions of diagonals to be the same, '
                'got {} and {}'.format(m1, m2))
        m = m1 or m2
        diagonals = gen_array_ops.matrix_diag_part_v2(diagonals,
                                                      k=(-1, 1),
                                                      padding_value=0.)
        # matrix_diag_part pads at the end. Because the subdiagonal uses the
        # convention of padding at the front, we roll it by one element.
        superdiag, d, subdiag = array_ops.unstack(diagonals, num=3, axis=-2)
        subdiag = manip_ops.roll(subdiag, shift=1, axis=-1)
        diagonals = array_ops.stack((superdiag, d, subdiag), axis=-2)
        return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                                 conjugate_rhs,
                                                 partial_pivoting, name)

    raise ValueError(
        'Unrecognized diagonals_format: {}'.format(diagonals_format))
Example #24
File: losses.py  Project: nlaanait/stemdl
def fftshift(tensor, tens_format='NCHW'):
    dims = [2, 3] if tens_format == 'NCHW' else [1, 2]
    shift = [int(tensor.shape[dim] // 2) for dim in dims]
    shift_tensor = manip_ops.roll(tensor, shift, dims)
    return shift_tensor
Example #25
 def _testRoll(self, np_input, shift, axis):
     expected_roll = np.roll(np_input, shift, axis)
     with self.cached_session(use_gpu=True):
         roll = manip_ops.roll(np_input, shift, axis)
         self.assertAllEqual(roll, expected_roll)
Example #26
 def _testRoll(self, np_input, shift, axis):
   expected_roll = np.roll(np_input, shift, axis)
   with self.test_session():
     roll = manip_ops.roll(np_input, shift, axis)
     self.assertAllEqual(roll.eval(), expected_roll)
Example #27
 def testInvalidShiftAndAxisNotEqualShape(self):
      # The shift and axis must be the same size, checked in the shape function.
     with self.assertRaisesRegex(ValueError, "both shapes must be equal"):
         manip_ops.roll([[1, 2], [3, 4]], [1], [0, 1])
Example #28
 def forward(self, image, param):
     return manip_ops.roll(image, param, axis=1)
Example #29
 def backward(self, image, param):
     return manip_ops.roll(image, -param, axis=1)
Example #30
 def testInvalidShiftShape(self):
     # The shift should be a scalar or 1-D, checked in shape function.
     with self.assertRaisesRegex(
             ValueError, "Shape must be at most rank 1 but is rank 2"):
         manip_ops.roll([[1, 2], [3, 4]], [[0, 1]], 1)
Example #31
 def testInvalidShiftShape(self):
   # The shift should be a scalar or 1-D, checked in shape function.
   with self.assertRaisesRegexp(
       ValueError, "Shape must be at most rank 1 but is rank 2"):
     manip_ops.roll([[1, 2], [3, 4]], [[0, 1]], 1)
Example #32
 def testInvalidInputShape(self):
   # The input should be 1-D or higher, checked in shape function.
   with self.assertRaisesRegexp(
       ValueError, "Shape must be at least rank 1 but is rank 0"):
     manip_ops.roll(7, 1, 0)
Example #33
def _temptf_ifft_shift(x):
    # taken from https://github.com/tensorflow/tensorflow/pull/27075/files
    shift = [
        -tf.cast(tf.shape(x)[ax] // 2, tf.int32) for ax in FOURIER_SHIFT_AXES
    ]
    return manip_ops.roll(x, shift, FOURIER_SHIFT_AXES)
Example #34
def _RollGrad(op, grad):
  # The gradient is just the roll reversed
  shift = op.inputs[1]
  axis = op.inputs[2]
  roll_grad = manip_ops.roll(grad, -shift, axis)
  return roll_grad, None, None
Example #35
def _RollGrad(op, grad):
    # The gradient is just the roll reversed
    shift = op.inputs[1]
    axis = op.inputs[2]
    roll_grad = manip_ops.roll(grad, -shift, axis)
    return roll_grad, None, None
Example #36
 def testInvalidInputShape(self):
     # The input should be 1-D or higher, checked in shape function.
     with self.assertRaisesRegex(
             ValueError, "Shape must be at least rank 1 but is rank 0"):
         manip_ops.roll(7, 1, 0)
Example #37
 def testInvalidShiftAndAxisNotEqualShape(self):
    # The shift and axis must be the same size, checked in the shape function.
   with self.assertRaisesRegexp(ValueError, "both shapes must be equal"):
     manip_ops.roll([[1, 2], [3, 4]], [1], [0, 1])