Example #1
    def _check_domain_range_possibly_add_asserts(self):
        """Static check of init arg `num_rows`, possibly add asserts."""
        # Possibly add asserts.
        if self._assert_proper_shapes:
            self._num_rows = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._num_rows,
                    0,
                    message="Argument num_rows must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_rows,
                    message="Argument num_rows must be non-negative."),
            ], self._num_rows)
            self._num_columns = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._num_columns,
                    0,
                    message="Argument num_columns must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_columns,
                    message="Argument num_columns must be non-negative."),
            ], self._num_columns)

        # Static checks.
        if not np.issubdtype(self._num_rows.dtype, np.integer):
            raise TypeError("Argument num_rows must be integer type.  Found:"
                            " %s" % self._num_rows)

        if not np.issubdtype(self._num_columns.dtype, np.integer):
            raise TypeError(
                "Argument num_columns must be integer type.  Found:"
                " %s" % self._num_columns)

        num_rows_static = self._num_rows_static
        num_columns_static = self._num_columns_static

        if num_rows_static is not None:
            if num_rows_static.ndim != 0:
                raise ValueError(
                    "Argument num_rows must be a 0-D Tensor.  Found:"
                    " %s" % num_rows_static)

            if num_rows_static < 0:
                raise ValueError(
                    "Argument num_rows must be non-negative.  Found:"
                    " %s" % num_rows_static)
        if num_columns_static is not None:
            if num_columns_static.ndim != 0:
                raise ValueError(
                    "Argument num_columns must be a 0-D Tensor.  Found:"
                    " %s" % num_columns_static)

            if num_columns_static < 0:
                raise ValueError(
                    "Argument num_columns must be non-negative.  Found:"
                    " %s" % num_columns_static)
Example #2
    def _check_batch_shape_possibly_add_asserts(self):
        """Static check of init arg `batch_shape`, possibly add asserts."""
        if self._batch_shape_arg is None:
            return

        # Possibly add asserts
        if self._assert_proper_shapes:
            self._batch_shape_arg = control_flow_ops.with_dependencies([
                check_ops.assert_rank(
                    self._batch_shape_arg,
                    1,
                    message="Argument batch_shape must be a 1-D Tensor."),
                check_ops.assert_non_negative(
                    self._batch_shape_arg,
                    message="Argument batch_shape must be non-negative."),
            ], self._batch_shape_arg)

        # Static checks
        if not np.issubdtype(self._batch_shape_arg.dtype, np.integer):
            raise TypeError(
                "Argument batch_shape must be integer type.  Found:"
                " %s" % self._batch_shape_arg)

        if self._batch_shape_static is None:
            return  # Cannot do any other static checks.

        if self._batch_shape_static.ndim != 1:
            raise ValueError(
                "Argument batch_shape must be a 1-D Tensor.  Found:"
                " %s" % self._batch_shape_static)

        if np.any(self._batch_shape_static < 0):
            raise ValueError(
                "Argument batch_shape must be non-negative.  Found:"
                "%s" % self._batch_shape_static)
Example #3
    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        if self._assert_proper_shapes:
            x = linalg.adjoint(x) if adjoint_arg else x
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = control_flow_ops.with_dependencies([aps], x)
        if self.is_square:
            # Note that adjoint has no effect since this matrix is self-adjoint.
            if adjoint_arg:
                output_shape = array_ops.concat([
                    array_ops.shape(x)[:-2],
                    [array_ops.shape(x)[-1], array_ops.shape(x)[-2]]
                ], axis=0)
            else:
                output_shape = array_ops.shape(x)

            return self._possibly_broadcast_batch_shape(
                array_ops.zeros(shape=output_shape, dtype=x.dtype))

        x_shape = array_ops.shape(x)
        n = self._num_columns if adjoint else self._num_rows
        m = x_shape[-2] if adjoint_arg else x_shape[-1]

        output_shape = array_ops.concat([x_shape[:-2], [n, m]], axis=0)

        zeros = array_ops.zeros(shape=output_shape, dtype=x.dtype)
        return self._possibly_broadcast_batch_shape(zeros)
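
This `_matmul` never materializes a matrix product; it just returns zeros with the shape an ordinary product would have. A minimal usage sketch, assuming the method backs `tf.linalg.LinearOperatorZeros.matmul`:

import tensorflow as tf

# A minimal sketch, assuming the _matmul above backs
# tf.linalg.LinearOperatorZeros.matmul: the result is always a zeros tensor
# shaped like the corresponding matrix product.
operator = tf.linalg.LinearOperatorZeros(num_rows=2, dtype=tf.float32)
x = tf.ones([2, 4])
y = operator.matmul(x)
print(y.shape)           # (2, 4)
print(tf.reduce_sum(y))  # tf.Tensor(0.0, shape=(), dtype=float32)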

    def _solve(self, rhs, adjoint=False, adjoint_arg=False):
        rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
        if self._assert_proper_shapes:
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, rhs)
            rhs = control_flow_ops.with_dependencies([aps], rhs)
        return rhs / self._make_multiplier_matrix(conjugate=adjoint)

    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        x = linalg.adjoint(x) if adjoint_arg else x
        if self._assert_proper_shapes:
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = control_flow_ops.with_dependencies([aps], x)
        return x * self._make_multiplier_matrix(conjugate=adjoint)
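
These two methods multiply or divide by a scalar rather than forming a dense matrix. A minimal sketch, assuming they come from `tf.linalg.LinearOperatorScaledIdentity`:

import tensorflow as tf

# A minimal sketch, assuming the _matmul/_solve above come from
# tf.linalg.LinearOperatorScaledIdentity: matmul scales by the multiplier,
# solve divides by it.
operator = tf.linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=3.0)
x = tf.constant([[1.0], [2.0]])
print(operator.matmul(x))  # [[3.], [6.]]
print(operator.solve(x))   # approximately [[0.333], [0.667]]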

    def _matmul(self, x, adjoint=False, adjoint_arg=False):
        # Note that adjoint has no effect since this matrix is self-adjoint.
        x = linalg.adjoint(x) if adjoint_arg else x
        if self._assert_proper_shapes:
            aps = linear_operator_util.assert_compatible_matrix_dimensions(
                self, x)
            x = control_flow_ops.with_dependencies([aps], x)
        return self._possibly_broadcast_batch_shape(x)
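
The final `_matmul` returns its input unchanged apart from batch broadcasting, which is what an identity operator should do. A minimal sketch, assuming it belongs to `tf.linalg.LinearOperatorIdentity`:

import tensorflow as tf

# A minimal sketch, assuming the _matmul above belongs to
# tf.linalg.LinearOperatorIdentity: the product is x itself, broadcast
# against the operator's batch shape.
operator = tf.linalg.LinearOperatorIdentity(
    num_rows=2, batch_shape=[3], dtype=tf.float32)
x = tf.ones([2, 4])
y = operator.matmul(x)
print(y.shape)  # (3, 2, 4): x broadcast against the operator's batch shape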