Example #1
0
    def _check_domain_range_possibly_add_asserts(self):
        """Static check of init arg `num_rows`, possibly add asserts."""
        # Possibly add asserts.
        if self._assert_proper_shapes:
            self._num_rows = distribution_util.with_dependencies([
                check_ops.assert_rank(
                    self._num_rows,
                    0,
                    message="Argument num_rows must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_rows,
                    message="Argument num_rows must be non-negative."),
            ], self._num_rows)
            self._num_columns = distribution_util.with_dependencies([
                check_ops.assert_rank(
                    self._num_columns,
                    0,
                    message="Argument num_columns must be a 0-D Tensor."),
                check_ops.assert_non_negative(
                    self._num_columns,
                    message="Argument num_columns must be non-negative."),
            ], self._num_columns)

        # Static checks.
        if not np.issubdtype(self._num_rows.dtype, np.integer):
            raise TypeError("Argument num_rows must be integer type.  Found:"
                            " %s" % self._num_rows)

        if not np.issubdtype(self._num_columns.dtype, np.integer):
            raise TypeError(
                "Argument num_columns must be integer type.  Found:"
                " %s" % self._num_columns)

        num_rows_static = self._num_rows_static
        num_columns_static = self._num_columns_static

        if num_rows_static is not None:
            if num_rows_static.ndim != 0:
                raise ValueError(
                    "Argument num_rows must be a 0-D Tensor.  Found:"
                    " %s" % num_rows_static)

            if num_rows_static < 0:
                raise ValueError(
                    "Argument num_rows must be non-negative.  Found:"
                    " %s" % num_rows_static)
        if num_columns_static is not None:
            if num_columns_static.ndim != 0:
                raise ValueError(
                    "Argument num_columns must be a 0-D Tensor.  Found:"
                    " %s" % num_columns_static)

            if num_columns_static < 0:
                raise ValueError(
                    "Argument num_columns must be non-negative.  Found:"
                    " %s" % num_columns_static)
    def _check_batch_shape_possibly_add_asserts(self):
        """Static check of init arg `batch_shape`, possibly add asserts."""
        if self._batch_shape_arg is None:
            return

        # Possibly add asserts
        if self._assert_proper_shapes:
            self._batch_shape_arg = distribution_util.with_dependencies([
                check_ops.assert_rank(
                    self._batch_shape_arg,
                    1,
                    message="Argument batch_shape must be a 1-D Tensor."),
                check_ops.assert_non_negative(
                    self._batch_shape_arg,
                    message="Argument batch_shape must be non-negative."),
            ], self._batch_shape_arg)

        # Static checks
        if not np.issubdtype(self._batch_shape_arg.dtype, np.integer):
            raise TypeError(
                "Argument batch_shape must be integer type.  Found:"
                " %s" % self._batch_shape_arg)

        if self._batch_shape_static is None:
            return  # Cannot do any other static checks.

        if self._batch_shape_static.ndim != 1:
            raise ValueError(
                "Argument batch_shape must be a 1-D Tensor.  Found:"
                " %s" % self._batch_shape_static)

        if np.any(self._batch_shape_static < 0):
            raise ValueError(
                "Argument batch_shape must be non-negative.  Found:"
                "%s" % self._batch_shape_static)