Пример #1
0
 def _parameter_control_dependencies(self, is_init):
     """Returns assertions validating `low < peak < high` and `temperature > 0`.

     Args:
       is_init: Python `bool`. A check runs only when
         `is_init != tensor_util.is_ref(param)` for some parameter it involves,
         so non-ref parameters are checked on the `is_init=True` pass and ref
         (variable) parameters on the `is_init=False` pass.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`.
     """
     if not self.validate_args:
         return []
     assertions = []
     peak = None  # Converted lazily and shared by both ordering checks.
     peak_changed = is_init != tensor_util.is_ref(self.peak)
     # Each ordering constraint involves two parameters, so it must be checked
     # whenever either side is due for validation on this pass. Keying it on
     # only one side would skip the check after the other side (e.g. a
     # variable `peak`) is modified.
     if peak_changed or is_init != tensor_util.is_ref(self.low):
         peak = tf.convert_to_tensor(self.peak)
         assertions.append(
             assert_util.assert_greater(
                 peak,
                 self.low,
                 message='`peak` must be greater than `low`.'))
     if peak_changed or is_init != tensor_util.is_ref(self.high):
         peak = tf.convert_to_tensor(self.peak) if peak is None else peak
         assertions.append(
             assert_util.assert_greater(
                 self.high,
                 peak,
                 message='`high` must be greater than `peak`.'))
     if is_init != tensor_util.is_ref(self.temperature):
         assertions.append(
             assert_util.assert_positive(
                 self.temperature,
                 message='`temperature` must be positive.'))
     return assertions
Пример #2
0
def _maybe_check_valid_map_values(map_values, validate_args):
    """Checks that `map_values` is a non-empty, strictly increasing vector.

    Rank and size are verified eagerly when the static shape provides them
    (raising `ValueError` on failure). When they are not statically known,
    and always for the monotonicity requirement, runtime assertions are
    produced if `validate_args` is `True`.

    Args:
      map_values: `Tensor` of values to validate.
      validate_args: Python `bool`; whether to emit runtime assertions.

    Returns:
      List of assertion ops (possibly empty).

    Raises:
      ValueError: if the static rank is known and differs from 1, or the
        static number of elements is known and equals 0.
    """
    checks = []

    rank_msg = 'Rank of map_values must be 1.'
    static_rank = tensorshape_util.rank(map_values.shape)
    if static_rank is None:
        if validate_args:
            checks.append(
                assert_util.assert_rank(map_values, 1, message=rank_msg))
    elif static_rank != 1:
        raise ValueError(rank_msg)

    size_msg = 'Size of map_values must be greater than 0.'
    static_size = tensorshape_util.num_elements(map_values.shape)
    if static_size is None:
        if validate_args:
            checks.append(
                assert_util.assert_greater(
                    tf.size(map_values), 0, message=size_msg))
    elif static_size == 0:
        raise ValueError(size_msg)

    # Monotonicity can only be checked on the values, hence at runtime.
    if validate_args:
        checks.append(
            assert_util.assert_equal(
                tf.math.is_strictly_increasing(map_values),
                True,
                message='map_values is not strictly increasing.'))

    return checks
Пример #3
0
  def _parameter_control_dependencies(self, is_init):
    """Returns assertions validating `mean_direction` and `concentration`.

    `mean_direction` must have a last dimension in [2, 5] and unit norm;
    `concentration` must be non-negative.

    Args:
      is_init: Python `bool`; each check runs only on the pass where
        `is_init != tensor_util.is_ref(param)`.

    Returns:
      A list of assertion ops; empty when `validate_args` is `False`.
    """
    if not self.validate_args:
      return []

    assertions = []
    if is_init != tensor_util.is_ref(self._mean_direction):
      # Convert only when a check actually consumes the value, so a variable
      # parameter is not read on passes where its checks are skipped.
      mean_direction = tf.convert_to_tensor(self.mean_direction)
      event_size = tf.shape(mean_direction)[-1]
      assertions.append(
          assert_util.assert_greater(
              event_size,
              1,
              message='`mean_direction` may not have scalar event shape'))
      assertions.append(
          assert_util.assert_less_equal(
              event_size,
              5,
              message='von Mises-Fisher ndims > 5 is not currently supported'))
      assertions.append(
          assert_util.assert_near(
              1.,
              tf.linalg.norm(mean_direction, axis=-1),
              message='`mean_direction` must be unit-length'))
    if is_init != tensor_util.is_ref(self._concentration):
      assertions.append(
          assert_util.assert_non_negative(
              tf.convert_to_tensor(self.concentration),
              message='`concentration` must be non-negative'))
    return assertions
Пример #4
0
 def _parameter_control_dependencies(self, is_init):
   """Collects validation assertions for the truncated Cauchy parameters.

   `low`, `high`, `loc` and `scale` must be finite, `scale` must be positive,
   and `high` must exceed `low`. Each check is emitted only on the pass where
   `is_init != tensor_util.is_ref(param)` for a parameter it involves.

   Args:
     is_init: Python `bool`, `True` for the construction-time pass.

   Returns:
     List of assertion ops; empty unless `validate_args` is `True`.
   """
   if not self.validate_args:
     return []

   low_due = is_init != tensor_util.is_ref(self.low)
   high_due = is_init != tensor_util.is_ref(self.high)

   checks = []
   low_t = None
   high_t = None
   if low_due:
     low_t = tf.convert_to_tensor(self.low)
     checks.append(
         assert_util.assert_finite(low_t, message='`low` is not finite'))
   if high_due:
     high_t = tf.convert_to_tensor(self.high)
     checks.append(
         assert_util.assert_finite(high_t, message='`high` is not finite'))
   if is_init != tensor_util.is_ref(self.loc):
     checks.append(
         assert_util.assert_finite(self.loc, message='`loc` is not finite'))
   if is_init != tensor_util.is_ref(self.scale):
     scale_t = tf.convert_to_tensor(self.scale)
     checks.append(
         assert_util.assert_positive(
             scale_t, message='`scale` must be positive'))
     checks.append(
         assert_util.assert_finite(scale_t, message='`scale` is not finite'))
   # The ordering constraint involves both bounds; reuse any conversion
   # already made above.
   if low_due or high_due:
     if low_t is None:
       low_t = tf.convert_to_tensor(self.low)
     if high_t is None:
       high_t = tf.convert_to_tensor(self.high)
     checks.append(
         assert_util.assert_greater(
             high_t,
             low_t,
             message='TruncatedCauchy not defined when `low >= high`.'))
   return checks
Пример #5
0
    def _parameter_control_dependencies(self, is_init):
        """Returns assertions validating `mean_direction` and `concentration`.

        `mean_direction` must have a last dimension of at least 2 and unit
        norm; `concentration` must be non-negative.

        Args:
          is_init: Python `bool`; each check runs only on the pass where
            `is_init != tensor_util.is_ref(param)`.

        Returns:
          A list of assertion ops; empty when `validate_args` is `False`.
        """
        if not self.validate_args:
            return []

        assertions = []
        if is_init != tensor_util.is_ref(self._mean_direction):
            # Convert lazily, so a variable parameter is not read on passes
            # where none of its checks run.
            mean_direction = tf.convert_to_tensor(self.mean_direction)
            assertions.append(
                assert_util.assert_greater(
                    tf.shape(mean_direction)[-1],
                    1,
                    message=
                    '`mean_direction` must be a vector of at least size 2.'))
            assertions.append(
                assert_util.assert_near(
                    tf.cast(1., self.dtype),
                    tf.linalg.norm(mean_direction, axis=-1),
                    message='`mean_direction` must be unit-length'))
        if is_init != tensor_util.is_ref(self._concentration):
            assertions.append(
                assert_util.assert_non_negative(
                    tf.convert_to_tensor(self.concentration),
                    message='`concentration` must be non-negative'))
        return assertions
Пример #6
0
    def _maybe_validate_split_sizes(self):
        """Validations for the `split_sizes` property.

        Rules: at most one element may be `-1` (an unknown split size), and
        no element may be below `-1`. Statically known values are checked
        eagerly; otherwise runtime assertions are returned when
        `validate_args` is `True`.

        Returns:
          List of assertion ops (possibly empty).

        Raises:
          ValueError: if the static value of `split_sizes` breaks either rule.
        """
        assertions = []
        split_sizes = tf.convert_to_tensor(self.split_sizes)
        split_sizes_ = tf.get_static_value(split_sizes)

        # Ensure `split_sizes` has no more than one unknown split size (=-1).
        message = '`{}` elements must have at most one `-1`.'
        if split_sizes_ is not None:
            # `np.sum` rather than builtin `sum`: the builtin iterates the
            # comparison result and raises TypeError for a 0-d static value.
            if np.sum(split_sizes_ == -1) > 1:
                raise ValueError(message.format(split_sizes))
        elif self.validate_args:
            assertions.append(
                assert_util.assert_less(
                    tf.reduce_sum(tf.cast(tf.equal(split_sizes, -1), tf.int32)),
                    2,
                    message=message.format(split_sizes)))

        message = '`{}` elements must be either non-negative integers or `-1`.'
        if split_sizes_ is not None:
            if np.any(split_sizes_ < -1):
                raise ValueError(message.format(split_sizes))
        elif self.validate_args:
            assertions.append(
                assert_util.assert_greater(
                    split_sizes, -2, message=message.format(split_sizes)))

        return assertions
Пример #7
0
 def _inverse_event_shape_tensor(self, output_shape):
   """Maps a forward event-shape `Tensor` to the inverse event shape.

   The inverse event has a last dimension one smaller than the forward one.
   With `validate_args`, the last forward dimension is first asserted > 1.

   Args:
     output_shape: shape `Tensor` of the forward event (presumably 1-D, given
       the `[:-1]` / `[-1]` indexing and axis-0 concat; confirm with callers).

   Returns:
     `output_shape` with its last entry decremented by one.
   """
   if self.validate_args:
     # Shape entries cannot be negative, so only the lower bound needs a check.
     last_dim_ok = assert_util.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1.")
     output_shape = distribution_util.with_dependencies(
         [last_dim_ok], output_shape)
   leading_dims = output_shape[:-1]
   shrunk_last = [output_shape[-1] - 1]
   return tf.concat([leading_dims, shrunk_last], axis=0)
Пример #8
0
 def _parameter_control_dependencies(self, is_init):
   """Returns an assertion that `power` is strictly greater than one.

   Args:
     is_init: Python `bool`; the check runs on the pass where
       `is_init != tensor_util.is_ref(self.power)`.

   Returns:
     List of assertion ops; empty when `validate_args` is `False` or the
     check does not apply to this pass.
   """
   if not self.validate_args or is_init == tensor_util.is_ref(self.power):
     return []
   one = np.ones([], dtype_util.as_numpy_dtype(self.power.dtype))
   return [
       assert_util.assert_greater(
           self.power, one, message='`power` must be greater than 1.')
   ]
Пример #9
0
 def _inverse_event_shape_tensor(self, output_shape):
   """Returns `output_shape` with its last entry reduced by one.

   When `validate_args` is set, the result is built under a control
   dependency on an assertion that the last entry exceeds 1.

   Args:
     output_shape: shape `Tensor` of the forward event (presumably 1-D;
       confirm with callers).

   Returns:
     Shape `Tensor` of the inverse event.
   """
   guards = []
   if self.validate_args:
     # Negative shape entries are impossible, so checking `> 1` suffices.
     guards.append(assert_util.assert_greater(
         output_shape[-1], 1, message="Need last dimension greater than 1."))
   with tf.control_dependencies(guards):
     head = output_shape[:-1]
     return tf.concat([head, [output_shape[-1] - 1]], axis=0)
Пример #10
0
  def _log_prob(self, x):
    """Log-density at `x`.

    Evaluates
      0.5 * (log(concentration) - log(2 pi) - 3 log(x))
        - concentration * (x - loc)**2 / (2 loc**2 x),
    i.e. the inverse Gaussian log-density with mean `loc` and shape
    `concentration`. With `validate_args`, `x` is asserted positive first.

    Args:
      x: floating-point `Tensor` of evaluation points.

    Returns:
      `Tensor` of log-densities.
    """
    positivity = []
    if self.validate_args:
      positivity.append(
          assert_util.assert_greater(
              x, tf.cast(0., x.dtype.base_dtype),
              message="x must be positive."))
    with tf.control_dependencies(positivity):
      concentration = self.concentration
      loc = self.loc
      log_normalizer = 0.5 * (tf.math.log(concentration) -
                              np.log(2. * np.pi) - 3. * tf.math.log(x))
      log_kernel = (-concentration * (x - loc)**2.) / (2. * loc**2. * x)
      return log_normalizer + log_kernel
Пример #11
0
 def _assertions(self, t):
     """Returns assertions that `t` strictly increases along its last axis.

     Args:
       t: `Tensor` given to the inverse transformation.

     Returns:
       A one-element list holding the assertion when `validate_args` is
       `True`, otherwise an empty list.
     """
     if self.validate_args:
         tail, head = t[..., 1:], t[..., :-1]
         return [
             assert_util.assert_greater(
                 tail,
                 head,
                 message=
                 'Inverse transformation input must be strictly increasing.')
         ]
     return []
Пример #12
0
 def _parameter_control_dependencies(self, is_init):
   """Returns an assertion that `df` is strictly greater than two.

   Args:
     is_init: Python `bool`; the check runs on the pass where
       `is_init != tensor_util.is_ref(self.df)`.

   Returns:
     List of assertion ops; empty when validation is disabled or the check
     does not apply to this pass.
   """
   if not self.validate_args:
     return []
   if is_init == tensor_util.is_ref(self.df):
     return []
   two = dtype_util.as_numpy_dtype(self.df.dtype)(2.)
   return [
       assert_util.assert_greater(
           self.df, two, message='`df` must be greater than 2.')
   ]
Пример #13
0
 def _parameter_control_dependencies(self, is_init):
   """Returns an assertion that `high` exceeds `low`.

   Nothing is checked when `validate_args` is `False` or either bound is
   `None`. The check runs on the construction pass when neither bound is a
   ref, and on the runtime pass when at least one of them is.

   Args:
     is_init: Python `bool`, `True` for the construction-time pass.

   Returns:
     List of assertion ops (possibly empty).
   """
   if not self.validate_args:
     return []
   if self.low is None or self.high is None:
     return []
   any_bound_is_ref = (tensor_util.is_ref(self.low) or
                       tensor_util.is_ref(self.high))
   if is_init == any_bound_is_ref:
     return []
   return [
       assert_util.assert_greater(
           self.high, self.low,
           message='Argument `high` must be greater than `low`.')
   ]
Пример #14
0
 def _assert_valid_inverse_input(self, y):
   """Returns assertions that `y` lies strictly inside (`low`, `high`).

   Only the bounds that are not `None` are checked, and only when
   `validate_args` is `True`.

   Args:
     y: `Tensor` given to the inverse transformation.

   Returns:
     List of zero, one or two assertion ops.
   """
   if not self.validate_args:
     return []
   checks = []
   if self.low is not None:
     checks.append(assert_util.assert_greater(
         y, self.low,
         message='Input must be greater than `low`.'))
   if self.high is not None:
     checks.append(assert_util.assert_less(
         y, self.high,
         message='Input must be less than `high`.'))
   return checks
Пример #15
0
    def _cdf(self, x):
        """Cumulative distribution function at `x`.

        Evaluates
          ndtr(sqrt(c / x) * (x / loc - 1))
            + exp(2 c / loc) * ndtr(-sqrt(c / x) * (x / loc + 1)),
        where `c` is `concentration`. The second product is computed as
        `exp(2 c / loc + log_ndtr(...))`. Computed directly, `exp(2 c / loc)`
        overflows to `inf` for large `c / loc` while `ndtr(...)` underflows to
        0, and `inf * 0` yields NaN.

        Args:
          x: floating-point `Tensor`; asserted positive when `validate_args`.

        Returns:
          `Tensor` of CDF values.
        """
        with tf.control_dependencies([
                assert_util.assert_greater(x,
                                           tf.cast(0., x.dtype.base_dtype),
                                           message="x must be positive.")
        ] if self.validate_args else []):
            concentration = self.concentration
            loc = self.loc
            sqrt_ratio = (concentration / x)**0.5
            # Combine the exponential and the normal tail in log space to
            # avoid the inf * 0 = NaN failure described above.
            return (special_math.ndtr(sqrt_ratio * (x / loc - 1.)) +
                    tf.exp(2. * concentration / loc + special_math.log_ndtr(
                        -sqrt_ratio * (x / loc + 1))))
Пример #16
0
 def _inverse_event_shape_tensor(self, output_shape):
   """Computes the inverse event shape from the forward one.

   Delegates to `_forward_event_shape_tensor(..., is_inverse=True)`. When
   `validate_args` is set, the result is additionally asserted to contain no
   negative entries.

   Args:
     output_shape: shape `Tensor` of the forward event.

   Returns:
     Shape `Tensor` of the inverse event.
   """
   input_shape = self._forward_event_shape_tensor(
       output_shape, is_inverse=True)
   if self.validate_args:
     no_negative_dims = assert_util.assert_greater(
         input_shape, -1,
         message='Invalid inverse shape; found negative size.')
     with tf.control_dependencies([no_negative_dims]):
       return tf.identity(input_shape)
   return input_shape
Пример #17
0
 def _assertions(self, t):
     """Returns assertions that every element of `t` lies in (-1, 1).

     Args:
       t: `Tensor` given to the inverse transformation.

     Returns:
       `[lower_check, upper_check]` when `validate_args` is `True`, else `[]`.
     """
     if self.validate_args:
         np_dtype = dtype_util.as_numpy_dtype(t.dtype)
         lower_check = assert_util.assert_greater(
             t,
             np_dtype(-1),
             message="Inverse transformation input must be greater than -1.")
         upper_check = assert_util.assert_less(
             t,
             np_dtype(1),
             message="Inverse transformation input must be less than 1.")
         return [lower_check, upper_check]
     return []
Пример #18
0
    def _parameter_control_dependencies(self, is_init):
        """Validates `locs`, `slopes` and their consistency with `kernels`.

        On the construction pass (`is_init=True`), statically known last
        dimensions are checked and `ValueError` is raised on a mismatch.
        `locs` and `slopes` must have broadcast-compatible last dimensions,
        and each must have last dimension `1` or `len(self.kernels) - 1`.
        When `validate_args` is set, runtime assertions also check that `locs`
        ascends strictly along its last axis and that `slopes` is positive.

        Args:
          is_init: Python `bool`, `True` for the construction-time pass.

        Returns:
          List of assertion ops.

        Raises:
          ValueError: if any of the static shape checks above fails.
        """

        def _static_last_dim(tensor):
            # `None` when either the shape or its last dimension is unknown.
            shape = tensor.shape
            if shape is None:
                return None
            return shape[-1]

        if is_init:
            num_locs = _static_last_dim(self.locs)
            num_slopes = _static_last_dim(self.slopes)

            # Check that locs and slopes have the same last dimension.
            if num_locs is not None and num_slopes is not None:
                if (num_locs != num_slopes and num_locs != 1
                        and num_slopes != 1):
                    raise ValueError(
                        'Expect that `locs` and `slopes` are broadcastable.')

            if num_locs is not None:
                if not (len(self.kernels) == num_locs + 1 or num_locs == 1):
                    raise ValueError(
                        'Expect that `locs` has last dimension `1` or `N - 1` where '
                        f'`N` is the number of kernels, but got {num_locs}')

            if num_slopes is not None:
                if not (len(self.kernels) == num_slopes + 1
                        or num_slopes == 1):
                    raise ValueError(
                        'Expect that `slopes` has last dimension `1` or `N - 1` where '
                        f'`N` is the number of kernels, but got {num_slopes}')

        if not self.validate_args:
            return []

        checks = []
        if is_init != tensor_util.is_ref(self.locs):
            locs = tf.convert_to_tensor(self.locs)
            checks.append(
                assert_util.assert_greater(
                    locs[..., 1:],
                    locs[..., :-1],
                    message='Expect that elements of `locs` are ascending.'))
        if is_init != tensor_util.is_ref(self.slopes):
            checks.append(
                assert_util.assert_positive(
                    tf.convert_to_tensor(self.slopes),
                    message='`slopes` must be positive.'))
        return checks
Пример #19
0
def _maybe_validate_target_accept_prob(target_accept_prob, validate_args):
    """Returns `target_accept_prob`, guarded to lie in the open interval (0, 1).

    Args:
      target_accept_prob: `Tensor` with a `dtype` attribute.
      validate_args: Python `bool`. When `False`, the input is returned
        unchanged. Otherwise it is passed through `tf.identity` under control
        dependencies on the two range assertions.

    Returns:
      `target_accept_prob`, possibly wrapped with the assertions.
    """
    if validate_args:
        dtype = target_accept_prob.dtype
        in_range = [
            assert_util.assert_greater(
                target_accept_prob,
                tf.zeros([], dtype=dtype),
                message='`target_accept_prob` must be > 0.'),
            assert_util.assert_less(
                target_accept_prob,
                tf.ones([], dtype=dtype),
                message='`target_accept_prob` must be < 1.'),
        ]
        with tf.control_dependencies(in_range):
            return tf.identity(target_accept_prob)
    return target_accept_prob
Пример #20
0
    def _maybe_assert_valid_y(self, y):
        """Returns `y`, guarded to lie strictly between -1 and 1.

        Args:
          y: `Tensor` given to the inverse transformation.

        Returns:
          `y` unchanged when `validate_args` is `False`. Otherwise `y` is
          wrapped via `distribution_util.with_dependencies` on the two range
          assertions.
        """
        if self.validate_args:
            as_np = dtype_util.as_numpy_dtype(y.dtype)
            checks = [
                assert_util.assert_greater(
                    y,
                    as_np(-1),
                    message="Inverse transformation input must be greater than -1."
                ),
                assert_util.assert_less(
                    y,
                    as_np(1),
                    message="Inverse transformation input must be less than 1."),
            ]
            return distribution_util.with_dependencies(checks, y)
        return y
Пример #21
0
    def _parameter_control_dependencies(self, is_init):
        """Validates the rank, dtype and values of `self.sample_shape`.

        Statically available information is checked eagerly and raises. When
        it is missing, runtime assertions are returned if `validate_args` is
        `True`.

        Args:
          is_init: Python `bool`, `True` for the construction-time pass.

        Returns:
          List of assertion ops (possibly empty).

        Raises:
          ValueError: if the static rank exceeds 1 or a static value is
            negative.
          TypeError: if the dtype is neither `int32` nor `int64`.
        """
        assertions = []
        memo = []  # Holds at most one converted `sample_shape` tensor.

        def _sample_shape_tensor():
            # Convert `self.sample_shape` lazily, at most once per call.
            if not memo:
                memo.append(tf.convert_to_tensor(self.sample_shape))
            return memo[0]

        # Rank. Inside this branch the rank is statically known exactly when
        # `is_init` is True: known ranks are checked eagerly, unknown ones by
        # a runtime assertion on the other pass.
        ndims_ = tensorshape_util.rank(self.sample_shape.shape)
        if is_init != (ndims_ is None):
            msg = 'Argument `sample_shape` must be either a scalar or a vector.'
            if ndims_ is None:
                if self.validate_args:
                    assertions.append(
                        assert_util.assert_less(
                            tf.rank(_sample_shape_tensor()), 2, message=msg))
            elif ndims_ > 1:
                raise ValueError(msg)

        # Dtype: checked only at construction (no xor check, since `dtype`
        # cannot change).
        if is_init:
            dtype_ = self.sample_shape.dtype
            if dtype_ is None:
                dtype_ = _sample_shape_tensor().dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError(
                    'Argument `sample_shape` must be integer type; '
                    'saw {}.'.format(dtype_util.name(dtype_)))

        # Values: must be non-negative.
        if is_init != tensor_util.is_ref(self.sample_shape):
            msg = 'Argument `sample_shape` must have non-negative values.'
            sample_shape_ = tf.get_static_value(self.sample_shape)
            if sample_shape_ is None:
                if self.validate_args:
                    assertions.append(
                        assert_util.assert_greater(
                            _sample_shape_tensor(), -1, message=msg))
            elif np.any(np.array(sample_shape_) < 0):
                raise ValueError('{} Saw: {}'.format(msg, sample_shape_))

        return assertions
Пример #22
0
def _maybe_check_valid_shape(shape, validate_args):
    """Checks that a shape `Tensor` is int-typed and otherwise sane.

    A valid shape has rank <= 1, contains at most one `-1` (the "infer this
    dimension" marker), and has no element below `-1`. Statically available
    information is checked eagerly. Otherwise, when `validate_args` is
    `True`, equivalent runtime assertions are returned.

    Args:
      shape: `Tensor` describing a shape.
      validate_args: Python `bool`; whether to emit runtime assertions.

    Returns:
      List of assertion ops (possibly empty).

    Raises:
      TypeError: if `shape` does not have an integer dtype.
      ValueError: if the static rank or static values break the rules above.
    """
    if not dtype_util.is_integer(shape.dtype):
        raise TypeError('`{}` dtype (`{}`) should be `int`-like.'.format(
            shape, dtype_util.name(shape.dtype)))

    assertions = []

    message = '`{}` rank should be <= 1.'
    if tensorshape_util.rank(shape.shape) is not None:
        if tensorshape_util.rank(shape.shape) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.rank(shape),
                                    2,
                                    message=message.format(shape)))

    shape_ = tf.get_static_value(shape)

    message = '`{}` elements must have at most one `-1`.'
    if shape_ is not None:
        # A rank-0 shape passes the check above. `np.sum` handles it, whereas
        # builtin `sum` would raise TypeError iterating a numpy scalar.
        if np.sum(shape_ == -1) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.reduce_sum(
                tf.cast(tf.equal(shape, -1), tf.int32)),
                                    2,
                                    message=message.format(shape)))

    # Only values below -1 are rejected, so 0 is accepted; the message says
    # "non-negative" to describe the rule actually enforced.
    message = '`{}` elements must be either non-negative integers or `-1`.'
    if shape_ is not None:
        if np.any(shape_ < -1):
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(shape,
                                       -2,
                                       message=message.format(shape)))

    return assertions
Пример #23
0
 def _parameter_control_dependencies(self, is_init):
     """Returns assertions validating `mass`, `width`, `smin` and `smax`.

     `mass` and `width` must be positive, `smin` non-negative, and `smax`
     strictly greater than `smin`.

     Args:
       is_init: Python `bool`; a check runs only on the pass where
         `is_init != tensor_util.is_ref(param)` for a parameter it involves.

     Returns:
       A list of assertion ops; empty when `validate_args` is `False`.
     """
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.mass):
         assertions.append(
             assert_util.assert_positive(
                 self.mass, message='Argument `mass` must be positive.'))
     if is_init != tensor_util.is_ref(self.width):
         assertions.append(
             assert_util.assert_positive(
                 self.width, message='Argument `width` must be positive.'))
     if is_init != tensor_util.is_ref(self.smin):
         assertions.append(
             assert_util.assert_non_negative(
                 self.smin,
                 message='Argument `smin` must be positive or zero.'))
     # `smax > smin` relates two parameters. It must be re-checked when either
     # one is due for validation, otherwise a variable `smin` that changes is
     # never compared against `smax` again.
     if (is_init != tensor_util.is_ref(self.smax) or
             is_init != tensor_util.is_ref(self.smin)):
         assertions.append(
             assert_util.assert_greater(
                 self.smax,
                 self.smin,
                 message='Argument `smax` must be larger than `smin`.'))
     return assertions
Пример #24
0
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

        Args:
          initial_distribution: A `Categorical`-like instance.
            Determines probability of first hidden state in Markov chain.
            The number of categories must match the number of categories of
            `transition_distribution` as well as both the rightmost batch
            dimension of `transition_distribution` and the rightmost batch
            dimension of `observation_distribution`.
          transition_distribution: A `Categorical`-like instance.
            The rightmost batch dimension indexes the probability distribution
            of each hidden state conditioned on the previous hidden state.
          observation_distribution: A `tfp.distributions.Distribution`-like
            instance.  The rightmost batch dimension indexes the distribution
            of each observation conditioned on the corresponding hidden state.
          num_steps: The number of steps taken in Markov chain. A python `int`.
            When `validate_args` is `True`, runtime assertions check that it
            is a scalar and at least 1.
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
            Default value: `False`.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
            Default value: `True`.
          name: Python `str` name prefixed to Ops created by this class.
            Default value: "HiddenMarkovModel".

        Raises:
          ValueError: if `initial_distribution` or `transition_distribution`
            has a statically known, non-scalar `event_shape`.
          ValueError: if `transition_distribution` or
            `observation_distribution` has a statically known scalar
            `batch_shape`.
          ValueError: if `transition_distribution` and
            `observation_distribution` have statically known but different
            rightmost batch dimensions.
        """

        # Must stay the first statement so it captures only the arguments.
        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            # Event ranks: check statically when the rank is known. An unknown
            # rank (`None`) must fall through to the runtime check rather than
            # being treated as "non-scalar".
            initial_event_rank = tensorshape_util.rank(
                initial_distribution.event_shape)
            if initial_event_rank is not None:
                if initial_event_rank != 0:
                    raise ValueError(
                        "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            transition_event_rank = tensorshape_util.rank(
                transition_distribution.event_shape)
            if transition_event_rank is not None:
                if transition_event_rank != 0:
                    raise ValueError(
                        "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (tensorshape_util.dims(transition_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (tensorshape_util.dims(observation_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations. Static values are used
            # when available; otherwise the last batch dimension is read from
            # the dynamic batch shape.
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (tensorshape_util.dims(transition_distribution.batch_shape)
                     is not None and tensorshape_util.as_list(
                         transition_distribution.batch_shape)[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (tensorshape_util.dims(
                        observation_distribution.batch_shape) is not None
                     and tensorshape_util.as_list(
                         observation_distribution.batch_shape)[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            # Re-derive the state count from the extracted log-probabilities.
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
Пример #25
0
def _potential_scale_reduction_single_state(state, independent_chain_ndims,
                                            split_chains, validate_args):
    """potential_scale_reduction for one single state `Tensor`.

    Args:
      state: `Tensor` whose leading dimension indexes samples and whose next
        `independent_chain_ndims` dimensions index chains.
      independent_chain_ndims: Python `int`, number of chain dimensions.
      split_chains: Python `bool`. If `True`, each chain is split in half
        (dropping a trailing sample when the count is odd), doubling the
        number of chains.
      validate_args: Python `bool`. When the sample count is not statically
        known, emit runtime assertions on the minimum sample count.

    Returns:
      `Tensor` computed as `((m + 1) / m) * sigma_2_plus / w - (n - 1) / (m n)`,
      where `n` is the number of samples, `m` the number of chains, `w` the
      within-chain variance and `sigma_2_plus = w + B / n`.

    Raises:
      ValueError: if the static sample count is below 4 (with
        `split_chains`) or below 2 (without).
    """
    with tf.name_scope('potential_scale_reduction_single_state'):
        # We assume exactly one leading dimension indexes e.g. correlated samples
        # from each Markov chain.
        state = tf.convert_to_tensor(state, name='state')

        n_samples_ = tf.compat.dimension_value(state.shape[0])
        if n_samples_ is not None:  # If available statically.
            if split_chains and n_samples_ < 4:
                raise ValueError(
                    'Must provide at least 4 samples when splitting chains. '
                    'Found {}'.format(n_samples_))
            if not split_chains and n_samples_ < 2:
                raise ValueError(
                    'Must provide at least 2 samples.  Found {}'.format(
                        n_samples_))
        elif validate_args:
            # "At least N" admits N itself, matching the static `< N` checks
            # above, so the runtime check must be `>=`, not `>`.
            if split_chains:
                min_samples = 4
                message = 'Must provide at least 4 samples when splitting chains.'
            else:
                min_samples = 2
                message = 'Must provide at least 2 samples.'
            assertions = [
                assert_util.assert_greater_equal(
                    tf.shape(state)[0], min_samples, message=message)
            ]
            with tf.control_dependencies(assertions):
                state = tf.identity(state)

        # Define so it's not a magic number.
        # Warning!  `if split_chains` logic assumes this is 1!
        sample_ndims = 1

        if split_chains:
            # Split the sample dimension in half, doubling the number of
            # independent chains.

            # For odd number of samples, keep all but the last sample.
            state_shape = prefer_static.shape(state)
            n_samples = state_shape[0]
            state = state[:n_samples - n_samples % 2]

            # Suppose state = [0, 1, 2, 3, 4, 5]
            # Step 1: reshape into [[0, 1, 2], [3, 4, 5]]
            # E.g. reshape states of shape [a, b] into [2, a//2, b].
            state = tf.reshape(
                state,
                prefer_static.concat([[2, n_samples // 2], state_shape[1:]],
                                     axis=0))
            # Step 2: Put the size `2` dimension in the right place to be treated as a
            # chain, changing [[0, 1, 2], [3, 4, 5]] into [[0, 3], [1, 4], [2, 5]],
            # reshaping [2, a//2, b] into [a//2, 2, b].
            state = tf.transpose(
                a=state,
                perm=prefer_static.concat(
                    [[1, 0], tf.range(2, tf.rank(state))], axis=0))

            # We're treating the new dim as indexing 2 chains, so increment.
            independent_chain_ndims += 1

        sample_axis = tf.range(0, sample_ndims)
        chain_axis = tf.range(sample_ndims,
                              sample_ndims + independent_chain_ndims)
        sample_and_chain_axis = tf.range(
            0, sample_ndims + independent_chain_ndims)

        n = _axis_size(state, sample_axis)
        m = _axis_size(state, chain_axis)

        # In the language of Brooks and Gelman (1998),
        # B / n is the between chain variance, the variance of the chain means.
        # W is the within sequence variance, the mean of the chain variances.
        b_div_n = _reduce_variance(tf.reduce_mean(state,
                                                  axis=sample_axis,
                                                  keepdims=True),
                                   sample_and_chain_axis,
                                   biased=False)
        w = tf.reduce_mean(_reduce_variance(state,
                                            sample_axis,
                                            keepdims=True,
                                            biased=True),
                           axis=sample_and_chain_axis)

        # sigma^2_+ is an estimate of the true variance, which would be unbiased if
        # each chain was drawn from the target.  c.f. "law of total variance."
        sigma_2_plus = w + b_div_n

        return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n)
Пример #26
0
def _effective_sample_size_single_state(states, filter_beyond_lag,
                                        filter_threshold,
                                        filter_beyond_positive_pairs,
                                        cross_chain_dims, validate_args):
    """Estimates the effective sample size (ESS) of one `Tensor` of draws.

    Args:
      states: `Tensor` (or convertible) of MCMC draws; axis 0 indexes the
        draws within a chain.
      filter_beyond_lag: Forwarded as `max_lags` to `stats.auto_correlation`;
        `None` keeps every lag.
      filter_threshold: Optional scalar. When given (and
        `filter_beyond_positive_pairs` is falsy), auto-correlation terms from
        the first lag whose auto-correlation drops below it onward are
        discarded.
      filter_beyond_positive_pairs: Python `bool`. When truthy, lags are
        summed in adjacent pairs and everything from the first negative pair
        sum onward is discarded. Takes precedence over `filter_threshold`.
      cross_chain_dims: Optional axes of `states` that index independent
        chains. When not `None`, the auto-correlation is pooled across chains
        (Vehtari et al., 2019) and the result is scaled by the chain count.
      validate_args: Python `bool`. When the chain count is not statically
        known, adds a runtime assertion that it exceeds 1.

    Returns:
      `Tensor` with the ESS estimate, reduced over axis 0 (and over
      `cross_chain_dims` when given).

    Raises:
      ValueError: if `cross_chain_dims` is given and the statically-known
        number of chains is less than 2.
    """

    with tf.name_scope('effective_sample_size_single_state'):

        states = tf.convert_to_tensor(states, name='states')
        dt = states.dtype

        def _keep_until_first(violations):
            """Turns a boolean tensor into a {1, 0} mask of dtype `dt`.

            Along axis 0 the mask is 1 strictly before the first `True` and 0
            from there on, e.g. [F, F, T, F] -> [1, 1, 0, 0].
            """
            seen = tf.cumsum(tf.cast(violations, dt), axis=0)
            return tf.maximum(1. - seen, 0.)

        # With filter_beyond_lag=None the full sequence of lags is kept.
        auto_cov = stats.auto_correlation(states,
                                          axis=0,
                                          max_lags=filter_beyond_lag,
                                          normalize=False)
        n = _axis_size(states, axis=0)

        if cross_chain_dims is None:
            auto_corr = auto_cov / auto_cov[:1]
            num_chains = 1
        else:
            num_chains = _axis_size(states, cross_chain_dims)
            num_chains_ = tf.get_static_value(num_chains)

            msg = (
                'When `cross_chain_dims` is not `None`, there must be > 1 chain '
                'in `states`.')
            checks = []
            if num_chains_ is None:
                if validate_args:
                    checks.append(
                        assert_util.assert_greater(num_chains, 1., message=msg))
            elif num_chains_ < 2:
                raise ValueError(msg)

            with tf.control_dependencies(checks):
                # R[k] from equation 10 of Vehtari et al. (2019):
                #
                #   R[k] := 1 - (W - 1/C * Sum_{c=1}^C s_c**2 R[k, c]) / (var^+)
                #
                # with
                #   C := number of chains, N := length of chains
                #   x_hat[c] := 1 / N Sum_{n=1}^N x[n, c]        (chain mean)
                #   x_hat := 1 / C Sum_{c=1}^C x_hat[c]           (overall mean)
                #   s_c**2 := 1 / (N - 1) Sum_n (x[n, c] - x_hat[c])**2
                #   W := 1/C Sum_{c=1}^C s_c**2           (within-chain variance)
                #   B := N / (C - 1) Sum_c (x_hat[c] - x_hat)**2  (between-chain)
                #   R[k, c] := auto_corr[k, c, ...]
                #   var^+ := (N - 1) / N * W + B / N
                chain_axes = prefer_static.non_negative_axis(
                    cross_chain_dims, prefer_static.rank(states))
                # Axis 0 (draws) is removed by the mean / the [0] lag below,
                # so the chain axes shift down by one in those tensors.
                chain_axes_no_draws = chain_axes - 1
                # B / N; biased=False makes the denominator C - 1.
                between_div_n = _reduce_variance(
                    tf.reduce_mean(states, axis=0),
                    biased=False,
                    axis=chain_axes_no_draws)
                # W * (N - 1) / N
                scaled_within = tf.reduce_mean(auto_cov[0],
                                               chain_axes_no_draws)
                # var^+
                var_plus = scaled_within + between_div_n
                # 1/C * Sum_{c=1}^C s_c**2 R[k, c]
                pooled_auto_cov = tf.reduce_mean(auto_cov, chain_axes)
                auto_corr = 1. - (scaled_within - pooled_auto_cov) / var_plus

        # With R[k] := auto_corr[k, ...],
        # ESS = N / {1 + 2 * Sum_{k=1}^N R[k] * (N - k) / N}
        #     = N / {-1 + 2 * Sum_{k=0}^N R[k] * (N - k) / N} (since R[0] = 1)
        #     approx N / {-1 + 2 * Sum_{k=0}^M R[k] * (N - k) / N}
        # where M is the filter_beyond_lag truncation point.
        lags = tf.range(0., _axis_size(auto_corr, axis=0))
        lag_weights = (n - lags) / n
        # Give the weights shape [M, 1, ..., 1] (same ndims as auto_corr) so
        # they broadcast along axis 0 only.
        corr_rank = tensorshape_util.rank(auto_corr.shape)
        if corr_rank is None:
            weight_shape = tf.concat(
                ([-1], tf.ones([tf.rank(auto_corr) - 1], dtype=tf.int32)),
                axis=0)
        else:
            weight_shape = [-1] + [1] * (corr_rank - 1)
        weighted_auto_corr = tf.reshape(lag_weights, weight_shape) * auto_corr

        if filter_beyond_positive_pairs:

            def _sum_pairs(x):
                x_len = tf.shape(x)[0]
                # An odd-length sequence loses its final element.
                x = x[:x_len - x_len % 2]
                pair_shape = tf.concat(
                    [[x_len // 2, 2], tf.shape(x)[1:]], axis=0)
                return tf.reduce_sum(tf.reshape(x, pair_shape), 1)

            # For auto-correlation spectra of reversible MCMC chains the pair
            # sums are positive; truncate at the first negative pair sum.
            # E.g. pair sums [0.2, 0.1, -0.1, -0.2] -> keep [1, 1, 0, 0].
            keep = _keep_until_first(_sum_pairs(auto_corr) < 0.)
            # This halves the length of weighted_auto_corr; the final sum
            # below is unaffected by that.
            weighted_auto_corr = _sum_pairs(weighted_auto_corr) * keep
        elif filter_threshold is not None:
            filter_threshold = tf.convert_to_tensor(filter_threshold,
                                                    dtype=dt,
                                                    name='filter_threshold')
            # Zero out every lag from the first one below the threshold on.
            # E.g. auto_corr [1, 0.5, 0.0, 0.3], threshold 0.2
            #      -> keep [1, 1, 0, 0].
            weighted_auto_corr *= _keep_until_first(
                auto_corr < filter_threshold)

        return num_chains * n / (-1 +
                                 2 * tf.reduce_sum(weighted_auto_corr, axis=0))
  def _parameter_control_dependencies(self, is_init):
    """Checks the HMM's parameters, statically when possible.

    Violations that can be decided from static shapes, dtypes or values raise
    `ValueError` immediately. When a check cannot be decided statically and
    `self.validate_args` is set, a runtime assertion op is returned instead.

    Args:
      is_init: Python `bool`; whether this is the constructor-time call. The
        `num_steps` check runs when `is_init` differs from
        `tensor_util.is_ref(self.num_steps)`. The dtype checks run only when
        `is_init` is `True`.

    Returns:
      assertions: Python `list` of runtime assertion ops (possibly empty).

    Raises:
      ValueError: if `num_steps` is statically non-scalar or less than 1; if
        the initial or transition distribution has a non-integer dtype; if
        the initial distribution has statically non-scalar events; if the
        observation or transition distribution has a statically scalar batch
        shape; or if the statically-known last batch dimensions of the
        transition and observation distributions differ.
    """
    assertions = []

    # Check num_steps is a scalar that's at least 1.
    if is_init != tensor_util.is_ref(self.num_steps):
      num_steps = tf.convert_to_tensor(self.num_steps)
      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        if np.ndim(num_steps_) != 0:
          raise ValueError(
              '`num_steps` must be a scalar but it has rank {}'.format(
                  np.ndim(num_steps_)))
        if num_steps_ < 1:
          raise ValueError('`num_steps` must be at least 1.')
      elif self.validate_args:
        message = '`num_steps` must be a scalar'
        assertions.append(
            assert_util.assert_rank_at_most(num_steps, 0, message=message))
        assertions.append(
            assert_util.assert_greater_equal(
                num_steps, 1,
                message='`num_steps` must be at least 1.'))

    # Check that the initial distribution has scalar events over the
    # integers.
    if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
      raise ValueError(
          '`initial_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.initial_distribution.dtype)))

    if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
        raise ValueError('`initial_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              ps.size(self.initial_distribution.event_shape_tensor()),
              0,
              message='`initial_distribution` must have scalar `event_dim`s'),
      ]

    # Check that the transition distribution is over the integers.
    if (is_init and
        not dtype_util.is_integer(self.transition_distribution.dtype)):
      raise ValueError(
          '`transition_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.transition_distribution.dtype)))

    # Check observations have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
      raise ValueError(
          "`observation_distribution` can't have scalar batches")

    # Check transitions have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
      raise ValueError(
          "`transition_distribution` can't have scalar batches")

    # Check compatibility of transition distribution and observation
    # distribution. A static comparison is only meaningful when BOTH last
    # batch dimensions are known; if either is unknown (unknown rank, or
    # known rank with an unknown last dimension) defer to a runtime check.
    # Scalar batch shapes were rejected above, so `[-1]` is safe whenever
    # the rank is known.
    tdbs = self.transition_distribution.batch_shape
    odbs = self.observation_distribution.batch_shape
    tdbs_last_ = (None if tensorshape_util.rank(tdbs) is None
                  else tf.compat.dimension_value(tdbs[-1]))
    odbs_last_ = (None if tensorshape_util.rank(odbs) is None
                  else tf.compat.dimension_value(odbs[-1]))
    if tdbs_last_ is not None and odbs_last_ is not None:
      if tdbs_last_ != odbs_last_:
        raise ValueError(
            '`transition_distribution` and `observation_distribution` '
            'must agree on last dimension of batch size')
    elif self.validate_args:
      tdbs = self.transition_distribution.batch_shape_tensor()
      odbs = self.observation_distribution.batch_shape_tensor()
      transition_precondition = assert_util.assert_greater(
          ps.size(tdbs), 0,
          message=('`transition_distribution` can\'t have scalar '
                   'batches'))
      observation_precondition = assert_util.assert_greater(
          ps.size(odbs), 0,
          message=('`observation_distribution` can\'t have scalar '
                   'batches'))
      # Only index [-1] once both batch shapes are known to be non-empty.
      with tf.control_dependencies([
          transition_precondition,
          observation_precondition]):
        assertions += [
            assert_util.assert_equal(
                tdbs[-1],
                odbs[-1],
                message=('`transition_distribution` and '
                         '`observation_distribution` '
                         'must agree on last dimension of batch size'))]

    return assertions
Пример #28
0
  def __init__(self,
               mean_direction,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name='VonMisesFisher'):
    """Constructs a von Mises-Fisher distribution.

    Args:
      mean_direction: Floating-point `Tensor` of shape [B1, ... Bn, D]. Unit
        vector giving the mode of the distribution, i.e. the unit-normalized
        direction of the mean. (It is *not* in general the mean itself; the
        mean generally lies off the support.) NOTE: `D` is currently limited
        to <= 5.
      concentration: Floating-point `Tensor` with batch shape [B1, ... Bn],
        broadcastable with `mean_direction`. Controls how tightly samples
        cluster around `mean_direction`: `concentration=0` is uniform on the
        unit hypersphere, and `concentration=+inf` is a `Deterministic`
        (delta) distribution at `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True`, parameters
        are checked for validity at the possible cost of runtime
        performance. When `False`, invalid inputs may silently produce
        incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g. mean, mode, variance) report "`NaN`" where undefined. When
        `False`, an exception is raised if any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: For known-bad arguments, i.e. unsupported event dimension.
    """
    # Must stay the first statement so it captures exactly the arguments.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      float_dtype = dtype_util.common_dtype(
          [mean_direction, concentration], tf.float32)
      mean_direction = tf.convert_to_tensor(
          mean_direction, name='mean_direction', dtype=float_dtype)
      concentration = tf.convert_to_tensor(
          concentration, name='concentration', dtype=float_dtype)

      checks = []
      if validate_args:
        checks.append(assert_util.assert_non_negative(
            concentration, message='`concentration` must be non-negative'))
        checks.append(assert_util.assert_greater(
            tf.shape(mean_direction)[-1], 1,
            message='`mean_direction` may not have scalar event shape'))
        checks.append(assert_util.assert_near(
            1., tf.linalg.norm(mean_direction, axis=-1),
            message='`mean_direction` must be unit-length'))

      # Statically-known event size D (None if unknown); requires rank >= 1.
      event_dim_ = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(mean_direction.shape, 1)[-1])
      if event_dim_ is not None and event_dim_ > 5:
        raise ValueError('vMF ndims > 5 is not currently supported')
      if validate_args:
        checks.append(assert_util.assert_less_equal(
            tf.shape(mean_direction)[-1], 5,
            message='vMF ndims > 5 is not currently supported'))

      with tf.control_dependencies(checks):
        self._mean_direction = tf.identity(mean_direction)
        self._concentration = tf.identity(concentration)
      dtype_util.assert_same_float_dtype(
          [self._mean_direction, self._concentration])

      # mean_direction is always reparameterized; concentration is too only
      # when the event dimension is statically 3 (via an inversion sampler).
      if event_dim_ == 3:
        reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
      else:
        reparameterization_type = reparameterization.NOT_REPARAMETERIZED
      # Explicit two-argument super(): a zero-argument super() would add
      # `__class__` to this frame's locals.
      super(VonMisesFisher, self).__init__(
          dtype=self._concentration.dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          name=name)
Пример #29
0
  def _parameter_control_dependencies(self, is_init):
    """Validates `axis` and `paddings`, statically when possible.

    Each argument is checked when `is_init` differs from
    `tensor_util.is_ref(arg)`: non-ref values are checked when `is_init` is
    `True`, refs (e.g. variables) when it is `False`. Statically decidable
    violations raise `ValueError`. Otherwise, when `self.validate_args` is set,
    runtime assertion ops are returned.

    Args:
      is_init: Python `bool`; whether this is the constructor-time call.

    Returns:
      assertions: Python `list` of runtime assertion ops (possibly empty).

    Raises:
      ValueError: if `axis` is statically of rank > 1, has a non-negative
        entry, or has repeated entries; if `paddings` is statically not of
        shape [k, 2] with k >= 1, or has a negative entry; or if the static
        element count of `axis` differs from the number of rows of
        `paddings`.
    """
    assertions = []

    # Lazily-converted tensor versions of the parameters, shared by the
    # checks below so each is converted at most once.
    axis = None
    paddings = None

    if is_init != tensor_util.is_ref(self.axis):
      # First we check the shape of the axis argument.
      msg = 'Argument `axis` must be scalar or vector.'
      if tensorshape_util.rank(self.axis.shape) is not None:
        if tensorshape_util.rank(self.axis.shape) > 1:
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_rank_at_most(
            axis, 1, message=msg))
      # Next we check the values of the axis argument.
      axis_ = tf.get_static_value(self.axis)
      msg = 'Argument `axis` must be negative.'
      if axis_ is not None:
        if np.any(axis_ > -1):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        assertions.append(assert_util.assert_less(axis, 0, message=msg))
      msg = 'Argument `axis` elements must be unique.'
      if axis_ is not None:
        if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
          raise ValueError(msg)
      elif self.validate_args:
        if axis is None: axis = tf.convert_to_tensor(self.axis)
        # Elements are unique iff the element count equals the count of
        # distinct elements. `tf.unique` needs a vector, so flatten first
        # (`axis` may be a scalar).
        flat_axis = tf.reshape(axis, [-1])
        assertions.append(assert_util.assert_equal(
            tf.size(flat_axis),
            tf.size(tf.unique(flat_axis).y),
            message=msg))

    if is_init != tensor_util.is_ref(self.paddings):
      # First we check the shape of the paddings argument.
      msg = 'Argument `paddings` must be a vector of pairs.'
      if tensorshape_util.is_fully_defined(self.paddings.shape):
        shape = np.int32(self.paddings.shape)
        if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None: paddings = tf.convert_to_tensor(self.paddings)
        # Only index the shape once its rank is known to be 2.
        with tf.control_dependencies([
            assert_util.assert_equal(tf.rank(paddings), 2, message=msg)]):
          shape = tf.shape(paddings)
          assertions.extend([
              assert_util.assert_greater(shape[0], 0, message=msg),
              assert_util.assert_equal(shape[1], 2, message=msg),
          ])
      # Next we check the values of the paddings argument.
      paddings_ = tf.get_static_value(self.paddings)
      msg = 'Argument `paddings` must be non-negative.'
      if paddings_ is not None:
        if np.any(paddings_ < 0):
          raise ValueError(msg)
      elif self.validate_args:
        if paddings is None: paddings = tf.convert_to_tensor(self.paddings)
        assertions.append(assert_util.assert_greater(
            paddings, -1, message=msg))

    # The length check involves both arguments. Re-run it at use time
    # whenever EITHER one is a ref; otherwise mutating a single variable
    # argument after construction would never be re-validated. When neither
    # is a ref, run it once at init.
    if is_init != (tensor_util.is_ref(self.axis) or
                   tensor_util.is_ref(self.paddings)):
      axis_ = tf.get_static_value(self.axis)
      if axis_ is None and axis is None:
        axis = tf.convert_to_tensor(self.axis)
      len_axis = prefer_static.size0(prefer_static.reshape(
          axis if axis_ is None else axis_, shape=-1))

      paddings_ = tf.get_static_value(self.paddings)
      if paddings_ is None and paddings is None:
        paddings = tf.convert_to_tensor(self.paddings)
      len_paddings = prefer_static.size0(
          paddings if paddings_ is None else paddings_)

      msg = ('Arguments `axis` and `paddings` must have the same number '
             'of elements.')
      if (prefer_static.is_numpy(len_axis) and
          prefer_static.is_numpy(len_paddings)):
        if len_axis != len_paddings:
          raise ValueError(msg + ' Saw: {}, {}.'.format(
              self.axis, self.paddings))
      elif self.validate_args:
        assertions.append(assert_util.assert_equal(
            len_axis, len_paddings, message=msg))

    return assertions
Пример #30
0
def kendalls_tau(y_true, y_pred, name=None):
    """Computes Kendall's Tau between two orderings of the same items.

    Kendall's Tau measures the correlation between ordinal rankings. The
    approach mirrors `scipy.stats.kendalltau`. Values may be of any sortable
    type; their argsort order defines the true or proposed ordinal sequence.

    Args:
      y_true: `Tensor` of shape `[n]` holding the true ordinal ranking.
      y_pred: `Tensor` of shape `[n]` holding the predicted ordering of the
        same `n` items.
      name: Optional Python `str` name for ops created by this method.
        Default value: `None` (i.e., 'kendalls_tau').

    Returns:
      kendalls_tau: `float32` scalar `Tensor` holding the 1945 tau-b
        statistic, which ignores the ordering of ties.

    Raises:
      ValueError: if the static shapes of `y_true` and `y_pred` are
        incompatible. Runtime assertions additionally fail if `y_true` is not
        rank 1, has fewer than 2 elements, or if either input is all ties.
    """
    with tf.name_scope(name or 'kendalls_tau'):
        in_type = dtype_util.common_dtype([y_true, y_pred],
                                          dtype_hint=tf.float32)
        y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
        y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
        tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
        input_checks = [
            assert_util.assert_rank(y_true, 1),
            assert_util.assert_greater(
                ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
        ]
        with tf.control_dependencies(input_checks):
            lexa = lexicographical_indirect_sort(y_true, y_pred)

        # Notation follows "A Computer Method for Calculating Kendall's Tau
        # with Ungrouped Data" by William Knight, Journal of the American
        # Statistical Association, Vol. 61, No. 314, Part 1 (Jun., 1966),
        # pp. 436-439: https://www.jstor.org/stable/2282833

        n = ps.size0(y_true)

        def _count_tied_pairs(starts_new_run):
            """Sums m * (m - 1) // 2 over every run of consecutive ties.

            `starts_new_run(first, i)` is `True` when position `i` differs
            from the run that began at position `first`.
            """
            def body(first, acc, i):
                boundary = starts_new_run(first, i)
                run = i - first
                return (tf.where(boundary, i, first),
                        tf.where(boundary, acc + (run * (run - 1)) // 2, acc),
                        i + 1)

            first, acc, _ = tf.while_loop(cond=lambda first, acc, i: i < n,
                                          body=body,
                                          loop_vars=(0, 0, 1))
            # Close the final run, which the loop never sees a boundary for.
            tail = n - first
            return acc + (tail * (tail - 1)) // 2

        # t: pairs tied jointly in y_true and y_pred.
        t = _count_tied_pairs(lambda first, i: tf.math.logical_or(
            tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]),
            tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]])))

        # v: pairs tied in y_true.
        v = _count_tied_pairs(lambda first, i: tf.not_equal(
            y_true[lexa[first]], y_true[lexa[i]]))

        # Count exchanges made while merge-sorting y_pred.
        exchanges, newperm = iterative_mergesort(y_pred, lexa)

        # u: pairs tied in y_pred.
        u = _count_tied_pairs(lambda first, i: tf.not_equal(
            y_pred[newperm[first]], y_pred[newperm[i]]))

        n0 = (n * (n - 1)) // 2
        degenerate_checks = [
            assert_util.assert_less(v, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_true.'),
            assert_util.assert_less(u, tf.cast(n0, tf.int32),
                                    'All ranks are ties for y_pred.')
        ]
        with tf.control_dependencies(degenerate_checks):
            numerator = (tf.cast(n0 - (u + v - t), tf.float32) -
                         2.0 * tf.cast(exchanges, tf.float32))
            denominator = tf.math.sqrt(
                tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
            return numerator / denominator