示例#1
0
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
    """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

    When `new_shape` is statically known it is returned directly (converted
    with `np.int32`) and no assertions are built. Otherwise the `-1` entry of
    `new_shape`, if any, is replaced by the size implied by `original_shape`.

    Args:
        original_shape: Integer vector giving the shape being reshaped.
        new_shape: Integer vector giving the requested shape. It may contain
            one `-1` entry whose size is inferred.
        validate: Python `bool`. If `True`, runtime assertions on both shapes
            are built and returned.
        name: Optional name scope for the ops created here.

    Returns:
        expanded_new_shape: `new_shape` with the `-1` entry filled in.
        batch_shape_static: The static shape inferred from `new_shape`.
        validations: List of assertion ops; empty unless `validate` is `True`
            and `new_shape` is not statically known.
    """
    batch_shape_static = tensorshape_util.constant_value_as_shape(new_shape)
    if tensorshape_util.is_fully_defined(batch_shape_static):
        return np.int32(batch_shape_static), batch_shape_static, []
    with tf.name_scope(name or 'calculate_reshape'):
        original_size = tf.reduce_prod(original_shape)
        implicit_dim = tf.equal(new_shape, -1)
        # With a single `-1`, `-prod(new_shape)` equals the product of the
        # explicit dimensions. Clamping to 1 keeps the division defined when
        # there is no `-1` (the negated product is negative) or when an
        # explicit dimension is 0.
        size_implicit_dim = (original_size //
                             tf.maximum(1, -tf.reduce_prod(new_shape)))
        expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
            implicit_dim, size_implicit_dim, new_shape)
        validations = [] if not validate else [  # pylint: disable=g-long-ternary
            assert_util.assert_rank(
                original_shape, 1, message='Original shape must be a vector.'),
            assert_util.assert_rank(
                new_shape, 1, message='New shape must be a vector.'),
            assert_util.assert_less_equal(
                tf.math.count_nonzero(implicit_dim, dtype=tf.int32),
                1,
                message='At most one dimension can be unknown.'),
            # Once the `-1` has been filled in, a negative entry can only come
            # from a requested value below -1. Zero-sized dimensions are valid
            # (the static path above accepts them too), so this must be a
            # non-negativity check rather than a positivity check.
            assert_util.assert_non_negative(
                expanded_new_shape, message='Shape elements must be >=-1.'),
            assert_util.assert_equal(tf.reduce_prod(expanded_new_shape),
                                     original_size,
                                     message='Shape sizes do not match.'),
        ]
        return expanded_new_shape, batch_shape_static, validations
示例#2
0
def _lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions.

    Each requirement is checked statically when enough shape information is
    available (raising `ValueError` immediately). Otherwise, if
    `validate_args` is `True`, a runtime assertion op is appended instead.

    Args:
        lower_upper: Tensor holding the packed LU factors. It must have rank
            at least 2, with square trailing (matrix) dimensions.
        perm: Tensor whose rank must be exactly one less than that of
            `lower_upper`.
        validate_args: Python `bool`; whether to build runtime assertions for
            requirements that cannot be checked statically.

    Returns:
        List of assertion ops (possibly empty).

    Raises:
        ValueError: If a requirement is statically known to be violated.
    """
    assertions = []
    lu_rank = tensorshape_util.rank(lower_upper.shape)
    perm_rank = tensorshape_util.rank(perm.shape)

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if lu_rank is not None:
        if lu_rank < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    if lu_rank is not None and perm_rank is not None:
        if lu_rank != perm_rank + 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    # Squareness depends only on the two trailing (matrix) dimensions, so
    # those are what must be statically known, not the leading batch dims.
    # Otherwise an unknown matrix dim would be compared statically, or the
    # runtime check would be skipped.
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
示例#3
0
def _maybe_check_valid_map_values(map_values, validate_args):
    """Collects the checks that `map_values` must satisfy.

    Rank and emptiness are checked statically when the shape allows it
    (raising `ValueError`). Otherwise they become runtime assertions when
    `validate_args` is set. The strictly-increasing requirement is always a
    runtime assertion, built only when `validate_args` is set.

    Args:
        map_values: Tensor of values to validate.
        validate_args: Python `bool`; whether to build runtime assertions.

    Returns:
        List of assertion ops (possibly empty).

    Raises:
        ValueError: If `map_values` is statically known not to be rank 1, or
            to be empty.
    """
    checks = []

    rank_msg = 'Rank of map_values must be 1.'
    static_rank = tensorshape_util.rank(map_values.shape)
    if static_rank is None:
        if validate_args:
            checks.append(
                assert_util.assert_rank(map_values, 1, message=rank_msg))
    elif static_rank != 1:
        raise ValueError(rank_msg)

    size_msg = 'Size of map_values must be greater than 0.'
    static_size = tensorshape_util.num_elements(map_values.shape)
    if static_size is None:
        if validate_args:
            checks.append(
                assert_util.assert_greater(tf.size(map_values), 0,
                                           message=size_msg))
    elif static_size == 0:
        raise ValueError(size_msg)

    # Monotonicity depends on the actual values, so it is only checked at
    # runtime.
    if validate_args:
        checks.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return checks
示例#4
0
  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Validates that `event_ndims` is an integer scalar >= `min_event_ndims`.

    Properties that are statically known are checked immediately. The rest
    become runtime assertions when `self.validate_args` is set.

    Args:
      min_event_ndims: Lower bound that `event_ndims` must meet.
      event_ndims: Value convertible to an integer `Tensor`.

    Returns:
      List of assertion ops (possibly empty).

    Raises:
      ValueError: If `event_ndims` has a non-integer dtype, or is statically
        known to be non-scalar or smaller than `min_event_ndims`.
    """
    event_ndims = tf.convert_to_tensor(event_ndims, name='event_ndims')
    static_ndims = tf.get_static_value(event_ndims)
    checks = []

    if not dtype_util.is_integer(event_ndims.dtype):
      raise ValueError('Expected integer dtype, got dtype {}'.format(
          event_ndims.dtype))

    # Lower-bound check: static when the value is known, runtime otherwise.
    if static_ndims is None:
      if self.validate_args:
        checks.append(
            assert_util.assert_greater_equal(event_ndims, min_event_ndims))
    else:
      if tensorshape_util.rank(event_ndims.shape) != 0:
        raise ValueError('Expected scalar event_ndims, got shape {}'.format(
            event_ndims.shape))
      if min_event_ndims > static_ndims:
        raise ValueError('event_ndims ({}) must be larger than '
                         'min_event_ndims ({})'.format(static_ndims,
                                                       min_event_ndims))

    # Scalar-shape check: static when the shape is fully known.
    if not tensorshape_util.is_fully_defined(event_ndims.shape):
      if self.validate_args:
        checks.append(
            assert_util.assert_rank(event_ndims, 0, message='Expected scalar.'))
    elif tensorshape_util.rank(event_ndims.shape) != 0:
      raise ValueError('Expected scalar shape, got ndims {}'.format(
          tensorshape_util.rank(event_ndims.shape)))
    return checks
示例#5
0
 def _parameter_control_dependencies(self, is_init):
     """Returns runtime assertions on the unexpanded batch shape.

     Nothing is built when `self.validate_args` is false. Otherwise the
     checks are built only when `is_init` differs from
     `tensor_util.is_ref(self._batch_shape_unexpanded)`.

     Args:
         is_init: Python `bool`; whether this is the call made during
             initialization.

     Returns:
         List of assertion ops (possibly empty).
     """
     if not self.validate_args:
         # Skip building the intermediates the assertions would need.
         return []
     batch_shape = self._batch_shape_unexpanded
     if is_init == tensor_util.is_ref(batch_shape):
         return []

     implicit_dim_mask = ps.equal(batch_shape, -1)
     is_vector = assert_util.assert_rank(
         batch_shape, 1, message='New shape must be a vector.')
     at_most_one_unknown = assert_util.assert_less_equal(
         tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32),
         1,
         message='At most one dimension can be unknown.')
     # Shifting by one maps the smallest allowed entry (-1) to 0.
     entries_in_range = assert_util.assert_non_negative(
         batch_shape + 1,
         message='Shape elements must be >=-1.')

     # The reshape must preserve the total number of elements.
     expanded_new_shape, original_size = self._calculate_new_shape()
     new_size = ps.reduce_prod(expanded_new_shape)
     sizes_match = assert_util.assert_equal(
         new_size,
         tf.cast(original_size, new_size.dtype),
         message='Shape sizes do not match.')

     return [is_vector, at_most_one_unknown, entries_in_range, sizes_match]
示例#6
0
def _maybe_validate_perm(perm, validate_args, name=None):
    """Checks that `perm` is valid.

    `perm` must have an integer dtype, be a vector, and be a permutation of
    `range(size(perm))`. Each property is checked statically when possible
    (raising immediately). Otherwise, if `validate_args` is `True`, a runtime
    assertion is returned instead.

    Args:
        perm: `Tensor` holding the candidate permutation.
        validate_args: Python `bool`; whether to build runtime assertions for
            properties that cannot be checked statically.
        name: Optional name scope for the created ops.

    Returns:
        List of assertion ops (possibly empty).

    Raises:
        TypeError: If `perm` does not have an integer dtype.
        ValueError: If `perm` is statically known not to be a vector, or not
            to be a valid permutation.
    """
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        if tensorshape_util.rank(perm.shape) is not None:
            if tensorshape_util.rank(perm.shape) != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(
                    tensorshape_util.rank(perm.shape)))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            # Build the reference range in `perm`'s own dtype. `tf.size`
            # defaults to int32, which would make `assert_equal` fail with a
            # dtype mismatch for int64 permutations.
            assertions += [
                assert_util.assert_equal(
                    tf.sort(perm),
                    tf.range(tf.size(input=perm, out_type=perm.dtype)),
                    message=msg)
            ]

        return assertions
示例#7
0
def _maybe_validate_rightmost_transposed_ndims(
    rightmost_transposed_ndims, validate_args, name=None):
  """Validates `rightmost_transposed_ndims` as a non-negative integer scalar.

  The dtype is always checked. Rank and sign are checked statically when
  known; otherwise they become runtime assertions when `validate_args` is set.

  Args:
    rightmost_transposed_ndims: `Tensor` to validate.
    validate_args: Python `bool`; whether to build runtime assertions.
    name: Optional name scope for the created ops.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: If the dtype is not an integer type.
    ValueError: If the value is statically known to be non-scalar or negative.
  """
  with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
    ndims = rightmost_transposed_ndims
    if not dtype_util.is_integer(ndims.dtype):
      raise TypeError('`rightmost_transposed_ndims` must be integer type.')

    checks = []
    static_rank = tensorshape_util.rank(ndims.shape)
    if static_rank is None:
      if validate_args:
        checks.append(assert_util.assert_rank(ndims, 0))
    elif static_rank != 0:
      raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                       'saw rank: {}.'.format(static_rank))

    static_value = tf.get_static_value(ndims)
    nonneg_msg = '`rightmost_transposed_ndims` must be non-negative.'
    if static_value is None:
      if validate_args:
        checks.append(
            assert_util.assert_non_negative(ndims, message=nonneg_msg))
    elif static_value < 0:
      raise ValueError(nonneg_msg[:-1] + ', saw: {}.'.format(static_value))

    return checks
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                       validate_args, name):
        """Validates a batch/event shape override supplied to `__init__`.

        `None` is treated as an empty override. If
        `_is_scalar_from_shape_tensor` is statically true for the override,
        `self._empty` is returned. Otherwise the override must be a vector
        with non-negative entries, and the base distribution and the override
        must not both be non-scalar. Statically known violations raise. The
        remaining conditions become runtime assertions attached to the result
        when `validate_args` is `True`.

        Args:
            override_shape: Shape override, or `None`.
            base_is_scalar: Whether the base distribution is scalar; passed to
                `prefer_static.logical_not`.
            validate_args: Python `bool`; whether to build runtime assertions.
            name: Name for the converted override tensor.

        Returns:
            `self._empty`, or the override as an `int32` tensor (with any
            runtime assertions attached as dependencies).

        Raises:
            TypeError: If the converted override is not an integer type.
            ValueError: If a requirement is statically known to be violated.
        """
        shape_tensor = tf.convert_to_tensor(
            value=[] if override_shape is None else override_shape,
            dtype=tf.int32,
            name=name)

        if not dtype_util.is_integer(shape_tensor.dtype):
            raise TypeError("shape override must be an integer")

        override_is_scalar = _is_scalar_from_shape_tensor(shape_tensor)
        if tf.get_static_value(override_is_scalar):
            return self._empty

        checks = []

        vector_msg = "shape override must be a vector"
        static_rank = tensorshape_util.rank(shape_tensor.shape)
        if static_rank is None:
            if validate_args:
                checks.append(
                    assert_util.assert_rank(shape_tensor, 1,
                                            message=vector_msg))
        elif static_rank != 1:
            raise ValueError(vector_msg)

        nonneg_msg = "shape override must have non-negative elements"
        static_shape = tf.get_static_value(shape_tensor)
        if static_shape is None:
            if validate_args:
                checks.append(
                    assert_util.assert_non_negative(shape_tensor,
                                                    message=nonneg_msg))
        elif any(dim < 0 for dim in static_shape):
            raise ValueError(nonneg_msg)

        is_both_nonscalar = prefer_static.logical_and(
            prefer_static.logical_not(base_is_scalar),
            prefer_static.logical_not(override_is_scalar))
        static_both_nonscalar = tf.get_static_value(is_both_nonscalar)
        if static_both_nonscalar is None:
            if validate_args:
                checks.append(
                    assert_util.assert_equal(
                        is_both_nonscalar,
                        False,
                        message="base distribution not scalar"))
        elif static_both_nonscalar:
            raise ValueError("base distribution not scalar")

        if checks:
            return distribution_util.with_dependencies(checks, shape_tensor)
        return shape_tensor
 def _parameter_control_dependencies(self, is_init):
   """Returns runtime checks on the batch shape parameter.

   Nothing is built unless `self.validate_args` is set and `is_init` differs
   from `tensor_util.is_ref(self._batch_shape_parameter)`.

   Args:
     is_init: Python `bool`; whether this is the call made during
       initialization.

   Returns:
     List of assertion ops (possibly empty).
   """
   if not self.validate_args:
     return []
   batch_shape = self._batch_shape_parameter
   if is_init == tensor_util.is_ref(batch_shape):
     return []
   # Assuming integer entries, "non-negative" and "> -1" are the same test.
   return [
       assert_util.assert_rank(
           batch_shape, 1, message='Batch shape must be a vector.'),
       assert_util.assert_non_negative(
           batch_shape, message='Shape elements must be >-1.'),
   ]
示例#10
0
  def _parameter_control_dependencies(self, is_init):
    """Returns runtime assertions that `num_steps` is a positive scalar.

    Nothing is built unless `self.validate_args` is set, `self._num_steps`
    is not `None`, and `is_init` differs from
    `tensor_util.is_ref(self._num_steps)`.

    Args:
      is_init: Python `bool`; whether this is the call made during
        initialization.

    Returns:
      List of assertion ops (possibly empty).
    """
    if not self.validate_args:
      return []
    num_steps = self._num_steps
    if num_steps is None or is_init == tensor_util.is_ref(num_steps):
      return []
    return [
        assert_util.assert_rank(
            num_steps, 0, message='Argument `num_steps` must be a scalar'),
        assert_util.assert_positive(
            num_steps, message='Argument `num_steps` must be positive'),
    ]
def _ensure_step_size_is_scalar(step_size, validate_args):
  """Returns `step_size` after confirming that it is a scalar.

  Nested structures are rejected immediately, and a statically known rank
  other than 0 raises `ValueError`. When the rank is only known at runtime and
  `validate_args` is set, `step_size` is passed through `tf.identity` under a
  rank-0 assertion. Otherwise it is returned unchanged.

  Args:
    step_size: Candidate step size.
    validate_args: Python `bool`; whether to add a runtime rank check.

  Returns:
    `step_size`, possibly wrapped in `tf.identity` with a control dependency.

  Raises:
    ValueError: If `step_size` is nested or statically non-scalar.
  """
  if tf.nest.is_nested(step_size):
    raise ValueError('Step size must be a scalar. Got: {}'.format(step_size))
  static_rank = tf.get_static_value(ps.rank(step_size))
  if static_rank is None:
    if not validate_args:
      return step_size
    rank_check = assert_util.assert_rank(
        step_size, 0, 'Step size must be a scalar.')
    with tf.control_dependencies([rank_check]):
      return tf.identity(step_size)
  if static_rank != 0:
    raise ValueError('Step size must be a scalar. Got: {}'.format(step_size))
  return step_size
def _maybe_validate_rightmost_transposed_ndims(
        initial_rightmost_transposed_ndims,
        rightmost_transposed_ndims,
        validate_args,
        name=None):
    """Checks that `rightmost_transposed_ndims` is valid.

    `rightmost_transposed_ndims` must meet all of these requirements:
    - have an integer dtype;
    - be a scalar;
    - be non-negative;
    - equal `initial_rightmost_transposed_ndims`, the value set when the
      `Transpose` bijector was constructed.

    Statically known violations of the rank and sign requirements raise
    immediately. Otherwise runtime assertions are returned when
    `validate_args` is `True`.

    Args:
        initial_rightmost_transposed_ndims: Construction-time value that
            `rightmost_transposed_ndims` must keep matching.
        rightmost_transposed_ndims: `Tensor` to validate.
        validate_args: Python `bool`; whether to build runtime assertions.
        name: Optional name scope for the created ops.

    Returns:
        List of assertion ops (possibly empty).

    Raises:
        TypeError: If `rightmost_transposed_ndims` is not an integer type.
        ValueError: If it is statically known to be non-scalar or negative.
    """
    with tf.name_scope(name or 'maybe_validate_rightmost_transposed_ndims'):
        assertions = []
        # Same dtype requirement that `_maybe_validate_perm` enforces on `perm`.
        if not dtype_util.is_integer(rightmost_transposed_ndims.dtype):
            raise TypeError(
                '`rightmost_transposed_ndims` must be integer type.')

        rank = tensorshape_util.rank(rightmost_transposed_ndims.shape)
        if rank is not None:
            if rank != 0:
                raise ValueError(
                    '`rightmost_transposed_ndims` must be a scalar, '
                    'saw rank: {}.'.format(rank))
        elif validate_args:
            assertions += [
                assert_util.assert_rank(rightmost_transposed_ndims, 0)
            ]

        rightmost_transposed_ndims_ = tf.get_static_value(
            rightmost_transposed_ndims)
        msg = '`rightmost_transposed_ndims` must be non-negative.'
        if rightmost_transposed_ndims_ is not None:
            if rightmost_transposed_ndims_ < 0:
                raise ValueError(
                    msg[:-1] +
                    ', saw: {}.'.format(rightmost_transposed_ndims_))
        elif validate_args:
            assertions += [
                assert_util.assert_non_negative(rightmost_transposed_ndims,
                                                message=msg),
                # The value can drift from its construction-time value
                # whenever it is not statically known. That includes the case
                # where its rank *is* known (e.g. a value backed by a scalar
                # variable), so this check is keyed on the value being
                # dynamic rather than on the rank being unknown.
                assert_util.assert_equal(
                    rightmost_transposed_ndims,
                    initial_rightmost_transposed_ndims,
                    message='`rightmost_transposed_ndims` must not change '
                    'from the value set when the `Transpose` '
                    'bijector was constructed.'),
            ]

        return assertions
def _maybe_validate_perm(initial_rightmost_transposed_ndims,
                         perm,
                         validate_args,
                         name=None):
    """Checks that `perm` is valid.

    `perm` must meet all of these requirements:
    - have an integer dtype;
    - be a vector;
    - keep the length `initial_rightmost_transposed_ndims` set when the
      `Transpose` bijector was constructed;
    - be a permutation of `range(size(perm))`.

    Statically known violations of the rank and permutation requirements
    raise immediately. Otherwise runtime assertions are returned when
    `validate_args` is `True`.

    Args:
        initial_rightmost_transposed_ndims: Construction-time value that the
            number of elements of `perm` must keep matching.
        perm: `Tensor` holding the candidate permutation.
        validate_args: Python `bool`; whether to build runtime assertions.
        name: Optional name scope for the created ops.

    Returns:
        List of assertion ops (possibly empty).

    Raises:
        TypeError: If `perm` does not have an integer dtype.
        ValueError: If `perm` is statically known not to be a vector, or not
            to be a valid permutation.
    """
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        if tensorshape_util.rank(perm.shape) is not None:
            if tensorshape_util.rank(perm.shape) != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(
                    tensorshape_util.rank(perm.shape)))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        # A statically known shape fixes `perm`'s length. The length can vary
        # at runtime both when the rank is unknown and when the rank is known
        # but the length is not (e.g. shape `[None]`). So the invariance check
        # is keyed on the element count, not on the rank.
        if (validate_args and
                tensorshape_util.num_elements(perm.shape) is None):
            assertions += [
                assert_util.assert_equal(
                    tf.size(perm),
                    initial_rightmost_transposed_ndims,
                    message='The number of elements of `perm` must not '
                    'change from the value set when the `Transpose` '
                    'bijector was constructed.')
            ]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            # Compare in `perm`'s dtype so int64 permutations do not trip a
            # dtype mismatch inside `assert_equal`.
            assertions += [
                assert_util.assert_equal(tf.sort(perm),
                                         tf.range(
                                             tf.size(perm,
                                                     out_type=perm.dtype)),
                                         message=msg)
            ]

        return assertions
示例#14
0
def kendalls_tau(y_true, y_pred, name=None):
    """Computes Kendall's Tau for two ordered lists.

    Kendall's Tau measures the correlation between ordinal rankings. This
    implementation is similar to the one used in scipy.stats.kendalltau.
    The provided values may be of any type that is sortable, with the
    argsort indices indicating the true or proposed ordinal sequence.

    Both inputs are converted to their common dtype (with `float32` as the
    dtype hint) and must have statically compatible shapes. The result is
    guarded by runtime assertions:
    - `y_true` has rank 1 and at least 2 elements;
    - neither `y_true` nor `y_pred` consists entirely of ties.

    Args:
      y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
      y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
        same N items.
      name: Optional Python `str` name for ops created by this method.
        Default value: `None` (i.e., 'kendalls_tau').

    Returns:
      kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
        ordering of ties, as a `float32` scalar Tensor.
    """
    with tf.name_scope(name or 'kendalls_tau'):
        in_type = dtype_util.common_dtype([y_true, y_pred],
                                          dtype_hint=tf.float32)
        y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
        y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
        tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
        assertions = [
            assert_util.assert_rank(y_true, 1),
            assert_util.assert_greater(
                ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
        ]
        # `lexa` (and therefore the tie and exchange counts built from it) is
        # computed only after the input assertions.
        with tf.control_dependencies(assertions):
            lexa = lexicographical_indirect_sort(y_true, y_pred)

        # See A Computer Method for Calculating Kendall's Tau with Ungrouped Data
        # by William R. Knight, Journal of the American Statistical Association,
        # Jun., 1966, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439
        # for notation https://www.jstor.org/stable/2282833
        #
        # Each tie-counting loop below scans a sorted order, with `first`
        # marking the start of the current run of equal values. When a run
        # ends at index `i`, it contributes C(i - first, 2) tied pairs. The
        # final run is added after the loop.

        # t: pairs tied in both y_true and y_pred. Assumes `lexa` orders by
        # y_true and then y_pred so jointly tied items are adjacent -- TODO
        # confirm against lexicographical_indirect_sort.
        def jointly_tied_pairs_body(first, t, i):
            not_equal = tf.math.logical_or(
                tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]),
                tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]]))
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             t + ((i - first) * (i - first - 1)) // 2,
                             t), i + 1)

        n = ps.size0(y_true)
        first, t, _ = tf.while_loop(cond=lambda first, t, i: i < n,
                                    body=jointly_tied_pairs_body,
                                    loop_vars=(0, 0, 1))
        t += ((n - first) * (n - first - 1)) // 2

        # v: pairs tied in y_true, scanned in `lexa` order.
        def ties_y_true_body(first, v, i):
            not_equal = tf.not_equal(y_true[lexa[first]], y_true[lexa[i]])
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             v + ((i - first) * (i - first - 1)) // 2,
                             v), i + 1)

        first, v, _ = tf.while_loop(cond=lambda first, v, i: i < n,
                                    body=ties_y_true_body,
                                    loop_vars=(0, 0, 1))
        v += ((n - first) * (n - first - 1)) // 2

        # count exchanges
        # NOTE(review): presumably `exchanges` is the number of swaps the
        # merge sort performs to order y_pred starting from `lexa` (i.e. the
        # discordant pairs), and `newperm` is the resulting order -- confirm
        # against iterative_mergesort.
        exchanges, newperm = iterative_mergesort(y_pred, lexa)

        # u: pairs tied in y_pred, scanned in `newperm` order.
        def ties_in_y_pred_body(first, u, i):
            not_equal = tf.not_equal(y_pred[newperm[first]],
                                     y_pred[newperm[i]])
            return (tf.where(not_equal, i, first),
                    tf.where(not_equal,
                             u + ((i - first) * (i - first - 1)) // 2,
                             u), i + 1)

        first, u, _ = tf.while_loop(cond=lambda first, u, i: i < n,
                                    body=ties_in_y_pred_body,
                                    loop_vars=(0, 0, 1))
        u += ((n - first) * (n - first - 1)) // 2
        # n0: total number of pairs, C(n, 2).
        n0 = (n * (n - 1)) // 2
        # If every pair is tied in either input, the denominator below is 0.
        assertions = [
            assert_util.assert_less(v, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_true.'),
            assert_util.assert_less(u, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_pred.')
        ]
        # tau_b = (n0 - v - u + t - 2 * exchanges) / sqrt((n0 - v) * (n0 - u))
        with tf.control_dependencies(assertions):
            return (tf.cast(n0 - (u + v - t), tf.float32) -
                    2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
                        tf.cast(n0 - v, tf.float32) *
                        tf.cast(n0 - u, tf.float32))
示例#15
0
    def _parameter_control_dependencies(self, is_init):
        """Returns assertions validating `outcomes` against `logits`/`probs`.

        Shape checks are built only during the `is_init` call; statically
        known violations raise `ValueError`. They cover:
        - non-empty `outcomes`;
        - rank-1 `outcomes`;
        - a last dimension that matches whichever of `logits`/`probs` the
          user supplied.

        The strictly-increasing check on `outcomes` is built when
        `self._validate_args` is set and `is_init` differs from
        `tensor_util.is_ref(outcomes)`.

        Args:
            is_init: Python `bool`; whether this is the call made during
                initialization.

        Returns:
            List of assertion ops (empty when `self._validate_args` is false).

        Raises:
            ValueError: If a shape requirement is statically known to be
                violated.
        """
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits  # pylint: disable=protected-access
        probs = self._categorical._probs  # pylint: disable=protected-access
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                """Checks matching last dims; returns an assertion or `None`."""
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)
                return None

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                # Use `assert_util`, like every other check in this method.
                assertions.append(
                    assert_util.assert_greater(tf.size(outcomes), 0,
                                               message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    logits,
                    message=
                    'Last dimension of outcomes and logits must be equal size.'
                )
                # Compare against `None` explicitly: the result is `None`
                # when no runtime check is needed, and may also be `None`
                # when assertions execute eagerly.
                if maybe_assert is not None:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message=
                    'Last dimension of outcomes and probs must be equal size.')
                if maybe_assert is not None:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            # Runtime assertions above are only created when `validate_args`.
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions