Example #1
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the `axis` dimension of the transformed
        `Tensor`.
      axis: Scalar `int` `Tensor` representing the dimension over which to
        `tf.gather`. `axis` must be relative to the end (reading left to
        right) and thus must be negative.
        Default value: `-1` (i.e., right-most).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not dtype_util.is_integer(permutation.dtype)`.
      ValueError: if `permutation` does not contain exactly one of each of
        `{0, 1, ..., d}`.
      NotImplementedError: if `axis` is not known prior to graph execution.
      NotImplementedError: if `axis` is not negative.
    """
        parameters = dict(locals())
        with tf.name_scope(name or 'permute') as name:
            axis = tensor_util.convert_nonref_to_tensor(axis,
                                                        name='axis',
                                                        as_shape_tensor=True)
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError('axis.dtype ({}) should be `int`-like.'.format(
                    dtype_util.name(axis.dtype)))
            axis_ = tf.get_static_value(axis)
            if axis_ is None:
                raise NotImplementedError(
                    '`axis` must be known prior to graph '
                    'execution.')
            elif axis_ >= 0:
                raise NotImplementedError(
                    '`axis` must be relative to the rightmost '
                    'dimension, i.e., negative.')
            forward_min_event_ndims = int(np.abs(axis_))
            self._axis = axis

            permutation = tensor_util.convert_nonref_to_tensor(
                permutation, name='permutation')
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    'permutation.dtype ({}) should be `int`-like.'.format(
                        dtype_util.name(permutation.dtype)))
            self._permutation = permutation

            super(Permute, self).__init__(
                forward_min_event_ndims=forward_min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                parameters=parameters,
                name=name)
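A minimal usage sketch for the bijector above (assuming the usual `tfp.bijectors` alias; values shown are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Reverse the right-most (event) dimension of a length-3 vector.
reverse = tfb.Permute(permutation=[2, 1, 0])
y = reverse.forward([-1., 0., 1.])   # ==> [1., 0., -1.]
x = reverse.inverse(y)               # ==> [-1., 0., 1.]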
Example #2
 def _value(self, dtype=None, name=None, as_ref=False):
     y = self.transform_fn(self.pretransformed_input)  # pylint: disable=not-callable
     if dtype_util.base_dtype(y.dtype) != self.dtype:
         raise TypeError(
             'Actual dtype ({}) does not match deferred dtype ({}).'.format(
                 dtype_util.name(dtype_util.base_dtype(y.dtype)),
                 dtype_util.name(self.dtype)))
     if not tensorshape_util.is_compatible_with(y.shape, self.shape):
         raise TypeError(
             'Actual shape ({}) is incompatible with deferred shape ({}).'.
             format(y.shape, self.shape))
     return tf.convert_to_tensor(y, dtype=dtype, name=name)
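This `_value` hook is what `tf.convert_to_tensor` ultimately invokes on a deferred tensor; a small sketch of the public behavior it supports (assuming `tfp.util.DeferredTensor`):

import tensorflow as tf
import tensorflow_probability as tfp

# A variable kept in log-space but exposed in positive space.
log_scale = tf.Variable(0., name='log_scale')
scale = tfp.util.DeferredTensor(log_scale, tf.math.exp)

# Conversion applies `transform_fn` and checks that the result's dtype and
# shape match what was promised at construction time.
tf.convert_to_tensor(scale)  # ==> 1.0, i.e., exp(0.)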
Example #3
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init != tensor_util.is_ref(self.permutation):
            if not dtype_util.is_integer(self.permutation.dtype):
                raise TypeError(
                    'permutation.dtype ({}) should be `int`-like.'.format(
                        dtype_util.name(self.permutation.dtype)))

            p = tf.get_static_value(self.permutation)
            if p is not None:
                if set(p) != set(np.arange(p.size)):
                    raise ValueError(
                        'Permutation over `d` must contain exactly one of '
                        'each of `{0, 1, ..., d}`.')

            if self.validate_args:
                p = tf.sort(self.permutation, axis=-1)
                assertions.append(
                    assert_util.assert_equal(
                        p,
                        tf.range(tf.shape(p)[-1]),
                        message=(
                            'Permutation over `d` must contain exactly one of '
                            'each of `{0, 1, ..., d}`.')))

        return assertions
Example #4
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        return assertions
Example #5
  def __init__(self,
               rate=None,
               log_rate=None,
               interpolate_nondiscrete=True,
               validate_args=False,
               allow_nan_stats=True,
               name='Poisson'):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be positive.
        Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter.
        Must specify exactly one of `rate` and `log_rate`.
      interpolate_nondiscrete: Python `bool`. When `False`,
        `log_prob` returns `-inf` (and `prob` returns `0`) for non-integer
        inputs. When `True`, `log_prob` evaluates the continuous function
        `k * log_rate - lgamma(k+1) - rate`, which matches the Poisson pmf
        at integer arguments `k` (note that this function is not itself
        a normalized probability log-density).
        Default value: `True`.
      validate_args: Python `bool`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    parameters = dict(locals())
    if (rate is None) == (log_rate is None):
      raise ValueError('Must specify exactly one of `rate` and `log_rate`.')
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([rate, log_rate], dtype_hint=tf.float32)
      if not dtype_util.is_floating(dtype):
        raise TypeError('[log_]rate.dtype ({}) is not a float-type.'.format(
            dtype_util.name(dtype)))
      self._rate = tensor_util.convert_nonref_to_tensor(
          rate, name='rate', dtype=dtype)
      self._log_rate = tensor_util.convert_nonref_to_tensor(
          log_rate, name='log_rate', dtype=dtype)

      self._interpolate_nondiscrete = interpolate_nondiscrete
      super(Poisson, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
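A brief usage sketch (values shown are approximate):

import tensorflow_probability as tfp

tfd = tfp.distributions

poisson = tfd.Poisson(log_rate=0.)    # Same as rate=exp(0.)=1.
poisson.prob(2.)                      # pmf at k=2 for rate 1: ~0.184.
# Specifying both (or neither) parameterization raises:
# tfd.Poisson(rate=1., log_rate=0.)   # ==> ValueError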
Example #6
 def trace_(x):
     """Prints something."""
     if hasattr(x, 'dtype') and hasattr(x, 'shape'):
         print('--- TRACE:  {}shape:{:16}  dtype:{:10}'.format(
             name, str(tensorshape_util.as_list(x.shape)),
             dtype_util.name(x.dtype)))
     else:
         print('--- TRACE:  {}value:{}'.format(name, x))
     sys.stdout.flush()
     return x
Example #7
 def __repr__(self):
     return ('<tfp.math.psd_kernels.{type_name} '
             '\'{self_name}\''
             ' batch_shape={batch_shape}'
             ' feature_ndims={feature_ndims}'
             ' dtype={dtype}>'.format(type_name=type(self).__name__,
                                      self_name=self.name,
                                      batch_shape=self.batch_shape,
                                      feature_ndims=self.feature_ndims,
                                      dtype=None if self.dtype is None else
                                      dtype_util.name(self.dtype)))
Example #8
def _maybe_validate_distributions(distributions, dtype_override,
                                  validate_args):
    """Checks that `distributions` satisfies all assumptions."""
    assertions = []

    if not _is_iterable(distributions) or not distributions:
        raise ValueError('`distributions` must be a list of one or more '
                         'distributions.')

    if dtype_override is None:
        dts = [
            dtype_util.base_dtype(d.dtype) for d in distributions
            if d.dtype is not None
        ]
        if dts[1:] != dts[:-1]:
            raise TypeError(
                'Distributions must have same dtype; found: {}.'.format(
                    set(dtype_util.name(dt) for dt in dts)))

    # Validate event_ndims.
    for d in distributions:
        if tensorshape_util.rank(d.event_shape) is not None:
            if tensorshape_util.rank(d.event_shape) != 1:
                raise ValueError('`Distribution` must be vector variate, '
                                 'found event ndims: {}.'.format(
                                     tensorshape_util.rank(d.event_shape)))
        elif validate_args:
            assertions.append(
                assert_util.assert_equal(
                    1,
                    tf.size(d.event_shape_tensor()),
                    message='`Distribution` must be vector variate.'))

    batch_shapes = [d.batch_shape for d in distributions]
    if all(tensorshape_util.is_fully_defined(b) for b in batch_shapes):
        if batch_shapes[1:] != batch_shapes[:-1]:
            raise ValueError('Distributions must have the same `batch_shape`; '
                             'found: {}.'.format(batch_shapes))
    elif validate_args:
        batch_shapes = [
            tensorshape_util.as_list(d.batch_shape)  # pylint: disable=g-complex-comprehension
            if tensorshape_util.is_fully_defined(d.batch_shape) else
            d.batch_shape_tensor() for d in distributions
        ]
        assertions.extend(
            assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
                b1,
                b2,
                message='Distribution `batch_shape`s must be identical.')
            for b1, b2 in zip(batch_shapes[1:], batch_shapes[:-1]))

    return assertions
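This helper appears to back `tfd.Blockwise`; a sketch of inputs that satisfy all three checks above (shared dtype, vector-variate events, equal batch shapes):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

block = tfd.Blockwise([
    tfd.MultivariateNormalDiag(loc=tf.zeros(3)),  # event_shape: [3]
    tfd.Dirichlet(concentration=tf.ones(4)),      # event_shape: [4]
], validate_args=True)
block.event_shape  # ==> [7], the concatenation of the parts.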
Example #9
 def __init__(self, distribution, dtype):
   parameters = dict(locals())
   name = 'CastTo{}'.format(dtype_util.name(dtype))
   with tf.name_scope(name) as name:
     self._distribution = distribution
     self._dtype = dtype
     super(_Cast, self).__init__(
         dtype=dtype,
         validate_args=distribution.validate_args,
         allow_nan_stats=distribution.allow_nan_stats,
         reparameterization_type=distribution.reparameterization_type,
         parameters=parameters,
         name=name)
Example #10
 def __str__(self):
     return ('tfp.math.psd_kernels.{type_name}('
             '"{self_name}"'
             '{maybe_batch_shape}'
             ', feature_ndims={feature_ndims}'
             ', dtype={dtype})'.format(
                 type_name=type(self).__name__,
                 self_name=self.name,
                 maybe_batch_shape=(', batch_shape={}'.format(
                     self.batch_shape) if tensorshape_util.rank(
                         self.batch_shape) is not None else ''),
                 feature_ndims=self.feature_ndims,
                 dtype=None
                 if self.dtype is None else dtype_util.name(self.dtype)))
Example #11
  def poisson_and_mixture_distributions(self):
    """Returns the Poisson and Mixture distribution parameterized by the quadrature grid and weights."""
    loc = tf.convert_to_tensor(self.loc)
    scale = tf.convert_to_tensor(self.scale)
    quadrature_grid, quadrature_probs = tuple(self._quadrature_fn(
        loc, scale, self.quadrature_size, self.validate_args))
    dt = quadrature_grid.dtype
    if not dtype_util.base_equal(dt, quadrature_probs.dtype):
      raise TypeError('Quadrature grid dtype ({}) does not match quadrature '
                      'probs dtype ({}).'.format(
                          dtype_util.name(dt),
                          dtype_util.name(quadrature_probs.dtype)))

    dist = poisson.Poisson(
        log_rate=quadrature_grid,
        validate_args=self.validate_args,
        allow_nan_stats=self.allow_nan_stats)

    mixture_dist = categorical.Categorical(
        logits=tf.math.log(quadrature_probs),
        validate_args=self.validate_args,
        allow_nan_stats=self.allow_nan_stats)
    return dist, mixture_dist
Example #12
def _maybe_validate_matrix(a, validate_args):
  """Checks that input is a `float` matrix."""
  assertions = []
  if not dtype_util.is_floating(a.dtype):
    raise TypeError('Input `a` must have `float`-like `dtype` '
                    '(saw {}).'.format(dtype_util.name(a.dtype)))
  if tensorshape_util.rank(a.shape) is not None:
    if tensorshape_util.rank(a.shape) < 2:
      raise ValueError('Input `a` must have at least 2 dimensions '
                       '(saw: {}).'.format(tensorshape_util.rank(a.shape)))
  elif validate_args:
    assertions.append(assert_util.assert_rank_at_least(
        a, rank=2, message='Input `a` must have at least 2 dimensions.'))
  return assertions
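A hypothetical caller, showing the usual pattern for such helpers: static failures raise immediately, while graph-time checks come back as a list and gate the computation via control dependencies:

import tensorflow as tf

def _checked_matrix(a, validate_args=False):
  a = tf.convert_to_tensor(a, name='a')
  assertions = _maybe_validate_matrix(a, validate_args)
  if assertions:
    with tf.control_dependencies(assertions):
      a = tf.identity(a)
  return a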
Example #13
 def __repr__(self):
     if tf.executing_eagerly():
         try:
             value = self._value()
         except Exception as e:  # pylint: disable=broad-except
             value = e
         value_str = ', numpy={}'.format(
             value if isinstance(value, Exception) else _numpy_text(value))
     else:
         value_str = ''
     return '<{}: dtype={}, shape={}, fn={}{}>'.format(
         type(self).__name__,
         dtype_util.name(self.dtype) if self.dtype else '?',
         str(
             tensorshape_util.as_list(self.shape) if tensorshape_util.
             rank(self.shape) is not None else '?').replace('None', '?'),
         self._fwd_name, value_str)
Example #14
    def _parameter_control_dependencies(self, is_init):
        assertions = []
        sample_shape = None  # Memoize concretization.

        # Check valid shape.
        ndims_ = tensorshape_util.rank(self.sample_shape.shape)
        if is_init != (ndims_ is None):
            msg = 'Argument `sample_shape` must be either a scalar or a vector.'
            if ndims_ is not None:
                if ndims_ > 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_less(tf.rank(sample_shape),
                                            2,
                                            message=msg))

        # Check valid dtype.
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = self.sample_shape.dtype
            if dtype_ is None:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                dtype_ = sample_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError(
                    'Argument `sample_shape` must be integer type; '
                    'saw {}.'.format(dtype_util.name(dtype_)))

        # Check valid "value".
        if is_init != tensor_util.is_ref(self.sample_shape):
            sample_shape_ = tf.get_static_value(self.sample_shape)
            msg = 'Argument `sample_shape` must have non-negative values.'
            if sample_shape_ is not None:
                if np.any(np.array(sample_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, sample_shape_))
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_greater(sample_shape, -1, message=msg))

        return assertions
Example #15
def _maybe_check_valid_shape(shape, validate_args):
    """Check that a shape Tensor is int-type and otherwise sane."""
    if not dtype_util.is_integer(shape.dtype):
        raise TypeError('`{}` dtype (`{}`) should be `int`-like.'.format(
            shape, dtype_util.name(shape.dtype)))

    assertions = []

    message = '`{}` rank should be <= 1.'
    if tensorshape_util.rank(shape.shape) is not None:
        if tensorshape_util.rank(shape.shape) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.rank(shape),
                                    2,
                                    message=message.format(shape)))

    shape_ = tf.get_static_value(shape)

    message = '`{}` elements must have at most one `-1`.'
    if shape_ is not None:
        if sum(shape_ == -1) > 1:
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_less(tf.reduce_sum(
                tf.cast(tf.equal(shape, -1), tf.int32)),
                                    2,
                                    message=message.format(shape)))

    message = '`{}` elements must be either positive integers or `-1`.'
    if shape_ is not None:
        if np.any(shape_ < -1):
            raise ValueError(message.format(shape))
    elif validate_args:
        assertions.append(
            assert_util.assert_greater(shape,
                                       -2,
                                       message=message.format(shape)))

    return assertions
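These checks presumably guard the `Reshape` bijector's event-shape arguments; a sketch of a shape that passes them (integer dtype, rank <= 1, at most one `-1`, nothing below `-1`):

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

reshape = tfb.Reshape(event_shape_out=[2, -1], event_shape_in=[6],
                      validate_args=True)
reshape.forward(tf.range(6.))  # ==> [[0., 1., 2.], [3., 4., 5.]]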
Example #16
def erfinv(x, name="erfinv"):
  """The inverse function for erf, the error function.

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="erfinv").

  Returns:
    x: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x` is not floating-type.
  """

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name="x")
    if dtype_util.as_numpy_dtype(x.dtype) not in [np.float32, np.float64]:
      raise TypeError("x.dtype={} is not handled, see docstring for supported "
                      "types.".format(dtype_util.name(x.dtype)))
    return ndtri((x + 1.) / 2.) / np.sqrt(2.)
Example #17
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar_fn,
                                       static_base_shape, is_init):
        """Helper which ensures override batch/event_shape are valid."""

        assertions = []
        concretized_shape = None

        # Check valid dtype
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = override_shape.dtype
            if dtype_ is None:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                dtype_ = concretized_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError('Shape override must be integer type; '
                                'saw {}.'.format(dtype_util.name(dtype_)))

        # Check non-negative elements
        if is_init != tensor_util.is_ref(override_shape):
            override_shape_ = tf.get_static_value(override_shape)
            msg = 'Shape override must have non-negative elements.'
            if override_shape_ is not None:
                if np.any(np.array(override_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, override_shape_))
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                assertions.append(
                    assert_util.assert_non_negative(concretized_shape,
                                                    message=msg))

        # Check valid shape
        override_ndims_ = tensorshape_util.rank(override_shape.shape)
        if is_init != (override_ndims_ is None):
            msg = 'Shape override must be a vector.'
            if override_ndims_ is not None:
                if override_ndims_ != 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_rank = tf.rank(concretized_shape)
                assertions.append(
                    assert_util.assert_equal(override_rank, 1, message=msg))

        static_base_rank = tensorshape_util.rank(static_base_shape)

        # Determine if the override shape is `[]` (static_override_dims == [0]),
        # in which case the base distribution may be nonscalar.
        static_override_dims = tensorshape_util.dims(override_shape.shape)

        if is_init != (static_base_rank is None
                       or static_override_dims is None):
            msg = 'Base distribution is not scalar.'
            if static_base_rank is not None and static_override_dims is not None:
                if static_base_rank != 0 and static_override_dims != [0]:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_is_empty = tf.logical_not(
                    self._has_nonzero_rank(concretized_shape))
                assertions.append(
                    assert_util.assert_equal(tf.logical_or(
                        base_is_scalar_fn(), override_is_empty),
                                             True,
                                             message=msg))
        return assertions
Example #18
    def __init__(self,
                 mix_loc,
                 temperature,
                 distribution,
                 loc=None,
                 scale=None,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_softmaxnormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorDiffeomixture"):
        """Constructs the VectorDiffeomixture on `R^d`.

    The vector diffeomixture (VDM) approximates the compound distribution

    ```none
    p(x) = int p(x | z) p(z) dz,
    where z is in the K-simplex, and
    p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k])
    ```

    Args:
      mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`.
        In terms of samples, larger `mix_loc[..., k]` ==>
        `Z` is more likely to put more weight on its `kth` component.
      temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`.
        In terms of samples, smaller `temperature` means one component is more
        likely to dominate.  I.e., smaller `temperature` makes the VDM look more
        like a standard mixture of `K` components.
      distribution: `tfp.distributions.Distribution`-like instance. Distribution
        from which `d` iid samples are used as input to the selected affine
        transformation. Must be a scalar-batch, scalar-event distribution.
        Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED`
        or it is a function of non-trainable parameters. WARNING: If you
        backprop through a VectorDiffeomixture sample and the `distribution`
        is not `FULLY_REPARAMETERIZED` yet is a function of trainable variables,
        then the gradient will be incorrect!
      loc: Length-`K` list of `float`-type `Tensor`s. The `k`-th element
        represents the `shift` used for the `k`-th affine transformation.  If
        the `k`-th item is `None`, `loc` is implicitly `0`.  When specified,
        must have shape `[B1, ..., Bb, d]` where `b >= 0` and `d` is the event
        size.
      scale: Length-`K` list of `LinearOperator`s. Each should be
        positive-definite and operate on a `d`-dimensional vector space. The
        `k`-th element represents the `scale` used for the `k`-th affine
        transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
        `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices.
      quadrature_size: Python `int` scalar representing number of
        quadrature points.  Larger `quadrature_size` means `q_N(x)` better
        approximates `p(x)`.
      quadrature_fn: Python callable taking `normal_loc`, `normal_scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the SoftmaxNormal grid and corresponding normalized
        weights.
        Default value: `quadrature_scheme_softmaxnormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not scale or len(scale) < 2`.
      ValueError: if `len(loc) != len(scale)`
      ValueError: if `quadrature_grid_and_probs is not None` and
        `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
      ValueError: if `validate_args` and any `not scale.is_positive_definite`.
      TypeError: if any `scale.dtype != scale[0].dtype`.
      TypeError: if any `loc.dtype != scale[0].dtype`.
      NotImplementedError: if `len(scale) != 2`.
      ValueError: if `not distribution.is_scalar_batch`.
      ValueError: if `not distribution.is_scalar_event`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            if not scale or len(scale) < 2:
                raise ValueError(
                    "Must specify list (or list-like object) of scale "
                    "LinearOperators, one for each component with "
                    "num_component >= 2.")

            if loc is None:
                loc = [None] * len(scale)

            if len(loc) != len(scale):
                raise ValueError("loc/scale must be same-length lists "
                                 "(or same-length list-like objects).")

            dtype = dtype_util.base_dtype(scale[0].dtype)

            loc = [
                tf.convert_to_tensor(
                    value=loc_, dtype=dtype, name="loc{}".format(k))
                if loc_ is not None else None for k, loc_ in enumerate(loc)
            ]

            for k, scale_ in enumerate(scale):
                if validate_args and not scale_.is_positive_definite:
                    raise ValueError(
                        "scale[{}].is_positive_definite = {} != True".format(
                            k, scale_.is_positive_definite))
                if dtype_util.base_dtype(scale_.dtype) != dtype:
                    raise TypeError(
                        "dtype mismatch; scale[{}].base_dtype=\"{}\" != \"{}\""
                        .format(k, dtype_util.name(scale_.dtype),
                                dtype_util.name(dtype)))

            self._endpoint_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="endpoint_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(zip(loc, scale))
            ]

            # TODO(jvdillon): Remove once we support k-mixtures.
            # We make this assertion here because otherwise `grid` would need to be a
            # vector not a scalar.
            if len(scale) != 2:
                raise NotImplementedError(
                    "Currently only bimixtures are supported; "
                    "len(scale)={} is not 2.".format(len(scale)))

            mix_loc = tf.convert_to_tensor(value=mix_loc,
                                           dtype=dtype,
                                           name="mix_loc")
            temperature = tf.convert_to_tensor(value=temperature,
                                               dtype=dtype,
                                               name="temperature")
            self._grid, probs = tuple(
                quadrature_fn(mix_loc / temperature, 1. / temperature,
                              quadrature_size, validate_args))

            # Note: by creating the logits as `log(prob)` we ensure that
            # `self.mixture_distribution.logits` is equivalent to
            # `math_ops.log(self.mixture_distribution.probs)`.
            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            asserts = distribution_util.maybe_check_scalar_distribution(
                distribution, dtype, validate_args)
            if asserts:
                self._grid = distribution_util.with_dependencies(
                    asserts, self._grid)
            self._distribution = distribution

            self._interpolated_affine = [
                affine_linear_operator_bijector.AffineLinearOperator(  # pylint: disable=g-complex-comprehension
                    shift=loc_,
                    scale=scale_,
                    validate_args=validate_args,
                    name="interpolated_affine_{}".format(k))
                for k, (loc_, scale_) in enumerate(
                    zip(interpolate_loc(self._grid, loc),
                        interpolate_scale(self._grid, scale)))
            ]

            [
                self._batch_shape_,
                self._batch_shape_tensor_,
                self._event_shape_,
                self._event_shape_tensor_,
            ] = determine_batch_event_shapes(self._grid, self._endpoint_affine)

            super(VectorDiffeomixture, self).__init__(
                dtype=dtype,
                # We hard-code `FULLY_REPARAMETERIZED` because when
                # `validate_args=True` we verify that indeed
                # `distribution.reparameterization_type == FULLY_REPARAMETERIZED`. A
                # distribution which is a function of only non-trainable parameters
                # also implies we can use `FULLY_REPARAMETERIZED`. However, we cannot
                # easily test for that possibility thus we use `validate_args=False`
                # as a "back-door" to allow users a way to use non
                # `FULLY_REPARAMETERIZED` distribution. In such cases IT IS THE USERS
                # RESPONSIBILITY to verify that the base distribution is a function of
                # non-trainable parameters.
                reparameterization_type=reparameterization.
                FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    distribution._graph_parents  # pylint: disable=protected-access
                    + [loc_ for loc_ in loc if loc_ is not None] +
                    [p for scale_ in scale for p in scale_.graph_parents]),  # pylint: disable=g-complex-comprehension
                name=name)
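A construction sketch adapted from the class's documented example (requires a TFP version that still ships `VectorDiffeomixture`):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dims = 5
# A bimixture interpolating between a standard Gaussian and a shifted,
# elongated one.
vdm = tfd.VectorDiffeomixture(
    mix_loc=[[0.], [1.]],
    temperature=[1.],
    distribution=tfd.Normal(loc=0., scale=1.),
    loc=[None,                         # `None` is equivalent to zeros.
         np.float32([2.] * dims)],
    scale=[
        tf.linalg.LinearOperatorScaledIdentity(
            num_rows=dims, multiplier=np.float32(1.1),
            is_positive_definite=True),
        tf.linalg.LinearOperatorDiag(
            diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
            is_positive_definite=True),
    ])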
Example #19
def sample_lkj(
    num_samples,
    dimension,
    concentration,
    cholesky_space=False,
    seed=None,
    name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    dimension: Python `int`. The dimension of correlation matrices.
    concentration: `Tensor` representing the concentration of the LKJ
      distribution.
    cholesky_space: Python `bool`. Whether to take samples from LKJ or
      Chol(LKJ).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    samples: A Tensor of correlation matrices (or Cholesky factors of
      correlation matrices if `cholesky_space = True`) with shape
      `[n] + B + [D, D]`, where `B` is the shape of the `concentration`
      parameter, and `D` is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
  """
  if dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)

  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    batch_shape = ps.concat([[num_samples], ps.shape(concentration)], axis=0)
    dtype = concentration.dtype
    if dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = ps.concat([batch_shape, [dimension, dimension]], axis=0)
      return tf.ones(shape=shape, dtype=dtype)

    # We need 1 seed for beta and 1 seed for tril_spherical_uniform.
    beta_seed, tril_spherical_uniform_seed = samplers.split_seed(
        seed, n=2, salt='sample_lkj')

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.
    # In addition, we also vectorize the computation to make the sampler
    # more feasible to use in problems where `dimension` is large.

    beta_conc = concentration + (dimension - 2.) / 2.
    dimension_range = np.arange(
        1., dimension, dtype=dtype_util.as_numpy_dtype(dtype))
    beta_conc1 = dimension_range / 2.
    beta_conc0 = beta_conc[..., tf.newaxis] - (dimension_range - 1) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc1, concentration0=beta_conc0)
    # norm is y in reference [1].
    norm = beta_dist.sample(sample_shape=[num_samples], seed=beta_seed)
    # distance shape: B + [dimension - 1, 1] for broadcast
    distance = tf.sqrt(norm)[..., tf.newaxis]

    # direction is u in reference [1].
    # direction follows the spherical uniform distribution and will be stored
    # in a lower triangular matrix, hence it will have shape:
    # B + [dimension - 1, dimension - 1]
    direction = _tril_spherical_uniform(dimension - 1, batch_shape, dtype,
                                        tril_spherical_uniform_seed)

    # raw_correlation is w in reference [1].
    # shape: B + [dimension - 1, dimension - 1]
    raw_correlation = distance * direction

    # This is the rows in the cholesky of the result,
    # which differs from the construction in reference [1].
    # In the reference, the new row `z` = chol_result @ raw_correlation^T
    # = C @ raw_correlation^T (where as short hand we use C = chol_result).
    # We prove that the below equation is the right row to add to the
    # cholesky, by showing equality with reference [1].
    # Let S be the sample constructed so far, and let `z` be as in
    # reference [1]. Then at this iteration, the new sample S' will be
    # [[S z^T]
    #  [z 1]]
    # In our case we have the cholesky decomposition factor C, so
    # we want our new row x (same size as z) to satisfy:
    #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
    #   [z 1]] =  [x k]]    [0     k]]  =       [xC^t   xx^T + k**2]]
    # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
    # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
    # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
    # distance**2).
    paddings_prepend = [[0, 0]] * len(batch_shape)
    diag = tf.pad(
        tf.sqrt(1. - norm), paddings_prepend + [[1, 0]], constant_values=1.)
    chol_result = tf.pad(
        raw_correlation,
        paddings_prepend + [[1, 0], [0, 1]],
        constant_values=0.)
    chol_result = tf.linalg.set_diag(chol_result, diag)

    if cholesky_space:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually set
    # these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=ps.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.linalg.cholesky` or
    # `tf.linalg.self_adjoint_eigvals` fail. Specifically, as documented in
    # b/116828694, around 2% of trials of 900,000 5x5 matrices (distributed
    # according to 9 different concentration parameter values) contained at
    # least one matrix on which the Cholesky decomposition failed.
    return result
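`sample_lkj` is module-internal; the public route to the same sampler is the `LKJ` distribution (a sketch):

import tensorflow_probability as tfp

tfd = tfp.distributions

lkj = tfd.LKJ(dimension=3, concentration=1.5)
samples = lkj.sample(4, seed=42)  # shape [4, 3, 3]: correlation matrices
                                  # with unit diagonals.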
Example #20
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    # Check num_steps is a scalar that's at least 1.
    if is_init != tensor_util.is_ref(self.num_steps):
      num_steps = tf.convert_to_tensor(self.num_steps)
      num_steps_ = tf.get_static_value(num_steps)
      if num_steps_ is not None:
        if np.ndim(num_steps_) != 0:
          raise ValueError(
              '`num_steps` must be a scalar but it has rank {}'.format(
                  np.ndim(num_steps_)))
        if num_steps_ < 1:
          raise ValueError('`num_steps` must be at least 1.')
      elif self.validate_args:
        message = '`num_steps` must be a scalar'
        assertions.append(
            assert_util.assert_rank_at_most(self.num_steps, 0, message=message))
        assertions.append(
            assert_util.assert_greater_equal(
                num_steps, 1,
                message='`num_steps` must be at least 1.'))

    # Check that the initial distribution has scalar events over the
    # integers.
    if is_init and not dtype_util.is_integer(self.initial_distribution.dtype):
      raise ValueError(
          '`initial_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.initial_distribution.dtype)))

    if tensorshape_util.rank(self.initial_distribution.event_shape) is not None:
      if tensorshape_util.rank(self.initial_distribution.event_shape) != 0:
        raise ValueError('`initial_distribution` must have scalar `event_dim`s')
    elif self.validate_args:
      assertions += [
          assert_util.assert_equal(
              ps.size(self.initial_distribution.event_shape_tensor()),
              0,
              message='`initial_distribution` must have scalar `event_dim`s'),
      ]

    # Check that the transition distribution is over the integers.
    if (is_init and
        not dtype_util.is_integer(self.transition_distribution.dtype)):
      raise ValueError(
          '`transition_distribution.dtype` ({}) is not over integers'.format(
              dtype_util.name(self.transition_distribution.dtype)))

    # Check observations have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.observation_distribution.batch_shape) == 0:
      raise ValueError(
          "`observation_distribution` can't have scalar batches")

    # Check transitions have non-scalar batches.
    # The graph version of this assertion is incorporated as
    # a control dependency of the transition/observation
    # compatibility test.
    if tensorshape_util.rank(self.transition_distribution.batch_shape) == 0:
      raise ValueError(
          "`transition_distribution` can't have scalar batches")

    # Check compatibility of transition distribution and observation
    # distribution.
    tdbs = self.transition_distribution.batch_shape
    odbs = self.observation_distribution.batch_shape
    if (tensorshape_util.dims(tdbs) is not None and
        tf.compat.dimension_value(odbs[-1]) is not None):
      if (tf.compat.dimension_value(tdbs[-1]) !=
          tf.compat.dimension_value(odbs[-1])):
        raise ValueError(
            '`transition_distribution` and `observation_distribution` '
            'must agree on last dimension of batch size')
    elif self.validate_args:
      tdbs = self.transition_distribution.batch_shape_tensor()
      odbs = self.observation_distribution.batch_shape_tensor()
      transition_precondition = assert_util.assert_greater(
          ps.size(tdbs), 0,
          message=('`transition_distribution` can\'t have scalar '
                   'batches'))
      observation_precondition = assert_util.assert_greater(
          ps.size(odbs), 0,
          message=('`observation_distribution` can\'t have scalar '
                   'batches'))
      with tf.control_dependencies([
          transition_precondition,
          observation_precondition]):
        assertions += [
            assert_util.assert_equal(
                tdbs[-1],
                odbs[-1],
                message=('`transition_distribution` and '
                         '`observation_distribution` '
                         'must agree on last dimension of batch size'))]

    return assertions
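A construction that satisfies every check above, sketched from the class's standard example: both `transition_distribution` and `observation_distribution` carry batch shape `[2]`, one batch member per hidden state:

import tensorflow_probability as tfp

tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7,
    validate_args=True)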
Example #21
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not dtype_util.is_integer(mixture_distribution.dtype)`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(s)
            self._event_size = tf.reduce_prod(s)

            if not dtype_util.is_integer(mixture_distribution.dtype):
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(dtype_util.name(mixture_distribution.dtype)))

            if (tensorshape_util.rank(mixture_distribution.event_shape)
                    is not None and tensorshape_util.rank(
                        mixture_distribution.event_shape) != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = tensorshape_util.with_rank_at_least(
                components_distribution.batch_shape, 1)[:-1]
            if tensorshape_util.is_fully_defined(
                    mdbs) and tensorshape_util.is_fully_defined(cdbs):
                if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(tensorshape_util.as_list(mdbs),
                                        tensorshape_util.as_list(cdbs)))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            mixture_dist_param = (mixture_distribution.probs
                                  if mixture_distribution.logits is None else
                                  mixture_distribution.logits)
            km = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                    1)[-1])
            kc = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(
                    components_distribution.batch_shape, 1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(mixture_dist_param)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(mixture_dist_param)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
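A minimal usage sketch: three mixture weights matching the components' right-most batch dimension, exactly the `km == kc` condition verified above:

import tensorflow_probability as tfp

tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.5, 0.2]),
    components_distribution=tfd.Normal(loc=[-1., 0., 1.],
                                       scale=[0.1, 0.5, 0.1]))
gm.mean()     # ==> 0.3*(-1.) + 0.5*0. + 0.2*1. = -0.1
gm.sample(5)  # Five scalar draws.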
Example #22
  def __init__(self,
               power,
               dtype=tf.int32,
               interpolate_nondiscrete=True,
               sample_maximum_iterations=100,
               validate_args=False,
               allow_nan_stats=False,
               name='Zipf'):
    """Initialize a batch of Zipf distributions.

    Args:
      power: `Float` like `Tensor` representing the power parameter. Must be
        strictly greater than `1`.
      dtype: The `dtype` of `Tensor` returned by `sample`.
        Default value: `tf.int32`.
      interpolate_nondiscrete: Python `bool`. When `False`, `log_prob` returns
        `-inf` (and `prob` returns `0`) for non-integer inputs. When `True`,
        `log_prob` evaluates the continuous function `-power log(k) -
        log(zeta(power))` , which matches the Zipf pmf at integer arguments `k`
        (note that this function is not itself a normalized probability
        log-density).
        Default value: `True`.
      sample_maximum_iterations: Maximum number of allowable iterations in
        `sample`. When `validate_args=True`, samples which fail to
        reach convergence (subject to this cap) are masked out with
        `self.dtype.min` or `nan` depending on `self.dtype.is_integer`.
        Default value: `100`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Zipf'`.

    Raises:
      TypeError: if `power` is not `float` like.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._power = tensor_util.convert_nonref_to_tensor(
          power,
          name='power',
          dtype=dtype_util.common_dtype([power], dtype_hint=tf.float32))
      if (not dtype_util.is_floating(self._power.dtype) or
          dtype_util.base_equal(self._power.dtype, tf.float16)):
        raise TypeError(
            'power.dtype ({}) is not a supported `float` type.'.format(
                dtype_util.name(self._power.dtype)))
      self._interpolate_nondiscrete = interpolate_nondiscrete
      self._sample_maximum_iterations = sample_maximum_iterations
      super(Zipf, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
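A short usage sketch (the pmf value follows from `1 / zeta(2) = 6 / pi**2`):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

zipf = tfd.Zipf(power=2.)            # `power` must exceed 1.
zipf.prob(1)                         # ==> ~0.6079
# tfd.Zipf(power=tf.constant(2))     # ==> TypeError: integer `power`.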
Example #23
def assign_log_moving_mean_exp(log_value, moving_log_mean_exp,
                               zero_debias_count=None, decay=0.99, name=None):
  """Compute the log of the exponentially weighted moving mean of the exp.

  If `log_value` is a draw from a stationary random variable, this function
  approximates `log(E[exp(log_value)])`, i.e., a weighted log-sum-exp. More
  precisely, a `tf.Variable`, `moving_log_mean_exp`, is updated by `log_value`
  using the following identity:

  ```none
  moving_log_mean_exp
  = log(decay * exp(moving_log_mean_exp) + (1 - decay) * exp(log_value))
  = log(exp(moving_log_mean_exp + log(decay)) + exp(log_value + log1p(-decay)))
  = moving_log_mean_exp
    + log(  exp(moving_log_mean_exp - moving_log_mean_exp + log(decay))
          + exp(log_value - moving_log_mean_exp + log1p(-decay)))
  = moving_log_mean_exp
    + log_sum_exp([log(decay),
                   log_value - moving_log_mean_exp + log1p(-decay)]).
  ```

  In addition to numerical stability, this formulation is advantageous because
  `moving_log_mean_exp` can be updated in a lock-free manner, i.e., using
  `assign_add`. (Note: the updates are not thread-safe; it's just that the
  update to the tf.Variable is presumed efficient due to being lock-free.)

  Args:
    log_value: `float`-like `Tensor` representing a new (streaming) observation.
      Same shape as `moving_log_mean_exp`.
    moving_log_mean_exp: `float`-like `Variable` representing the log of the
      exponentially weighted moving mean of the exp. Same shape as `log_value`.
    zero_debias_count: `int`-like `tf.Variable` representing the number of times
      this function has been called on streaming input (*not* the number of
      reduced values used in this function's computation). When not `None`, the
      returned value for `moving_log_mean_exp` is "zero debiased", i.e.,
      corrected for its presumed all-zeros initialization. Note: the
      `tf.Variable` `moving_log_mean_exp` *always* stores the biased
      calculation, regardless of setting this argument.
      Default value: `None` (i.e., no zero debiasing calculation is made).
    decay: A `float`-like `Tensor` representing the moving mean decay. Typically
      close to `1.`, e.g., `0.99`.
      Default value: `0.99`.
    name: Python `str` prepended to op names created by this function.
      Default value: `None` (i.e., 'assign_log_moving_mean_exp').

  Returns:
    moving_log_mean_exp: A reference to the input `Variable` with the
      `log_value`-updated log of the exponentially weighted moving mean of exp.

  Raises:
    TypeError: if `moving_log_mean_exp` does not have float type `dtype`.
    TypeError: if `moving_log_mean_exp`, `log_value`, `decay` have different
      `base_dtype`.
    NotImplementedError: if `zero_debias_count` is not `None`; zero debiasing
      is not yet supported by this function.
  """
  if zero_debias_count is not None:
    raise NotImplementedError(
        'Argument `zero_debias_count` is not yet supported. If you require '
        'this feature please create a new issue on '
        '`https://github.com/tensorflow/probability` or email '
        '`[email protected]`.')
  with tf.name_scope(name or 'assign_log_moving_mean_exp'):
    # We want to update the variable in a numerically stable and lock-free way.
    # To do this, observe that variable `x` updated by `v` is:
    # x = log(w exp(x) + (1-w) exp(v))
    #   = log(exp(x + log(w)) + exp(v + log1p(-w)))
    #   = x + log(exp(x - x + log(w)) + exp(v - x + log1p(-w)))
    #   = x + lse([log(w), v - x + log1p(-w)])
    base_dtype = dtype_util.base_dtype(moving_log_mean_exp.dtype)
    if not dtype_util.is_floating(base_dtype):
      raise TypeError(
          'Argument `moving_log_mean_exp` is not float type (saw {}).'.format(
              dtype_util.name(moving_log_mean_exp.dtype)))
    log_value = tf.convert_to_tensor(
        log_value, dtype=base_dtype, name='log_value')
    decay = tf.convert_to_tensor(decay, dtype=base_dtype, name='decay')
    delta = (log_value - moving_log_mean_exp)[tf.newaxis, ...]
    x = tf.concat([
        tf.broadcast_to(
            tf.math.log(decay),
            prefer_static.broadcast_shape(prefer_static.shape(decay),
                                          prefer_static.shape(delta))),
        delta + tf.math.log1p(-decay)
    ], axis=0)
    update = tf.reduce_logsumexp(x, axis=0)
    return moving_log_mean_exp.assign_add(update)
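A short sketch of the update above in action, assuming the function is
exported as `tfp.stats.assign_log_moving_mean_exp`; with a constant stream
the variable converges toward `log_value`:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Estimate of log(E[exp(log_value)]), initialized at log(1) = 0.
moving_log_mean_exp = tf.Variable(0.)

for _ in range(3):
  tfp.stats.assign_log_moving_mean_exp(
      log_value=1., moving_log_mean_exp=moving_log_mean_exp, decay=0.5)

# Each step computes x + log_sum_exp([log(0.5), 1. - x + log1p(-0.5)]),
# i.e., log(0.5 * exp(x) + 0.5 * exp(1.)); the value approaches 1.0.
print(moving_log_mean_exp.numpy())
```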
Exemplo n.º 24
0
def moving_mean_variance_zero_debiased(moving_mean, moving_variance=None,
                                       zero_debias_count=None, decay=0.99,
                                       name=None):
  """Compute zero debiased versions of `moving_mean` and `moving_variance`.

  Since `moving_*` variables initialized with `0`s will be biased (toward `0`),
  this function rescales the `moving_mean` and `moving_variance`, dividing each
  by the factor `1 - decay**zero_debias_count` so that they become unbiased.
  For more details, see [Kingma (2014)][1].

  Args:
    moving_mean: `float`-like `tf.Variable` representing the exponentially
      weighted moving mean. Same shape as `moving_variance` and `value`. This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
    moving_variance: `float`-like `tf.Variable` representing the exponentially
      weighted moving variance. Same shape as `moving_mean` and `value`.  This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
      Default value: `None` (i.e., no moving variance is computed).
    zero_debias_count: `int`-like `tf.Variable` representing the number of times
      the moving statistics have been updated on streaming input (*not* the
      number of reduced values used in their computation). The returned values
      for `moving_mean` and `moving_variance` are "zero debiased", i.e.,
      corrected for their presumed all-zeros initialization. Note: the
      `tf.Variable`s `moving_mean` and `moving_variance` *always* store the
      biased calculation; this function computes the debiased values without
      modifying the variables themselves.
      Default value: `None`; however, this argument is required and a
      `ValueError` is raised when it is left as `None`.
    decay: A `float`-like `Tensor` representing the moving mean decay. Typically
      close to `1.`, e.g., `0.99`.
      Default value: `0.99`.
    name: Python `str` prepended to op names created by this function.
      Default value: `None` (i.e., 'moving_mean_variance_zero_debiased').

  Returns:
    moving_mean: The zero debiased exponentially weighted moving mean.
    moving_variance: The zero debiased exponentially weighted moving variance.

  Raises:
    TypeError: if `moving_mean` does not have float type `dtype`.
    TypeError: if `moving_mean`, `moving_variance`, `decay` have different
      `base_dtype`.
    ValueError: if `zero_debias_count` is `None`.

  #### References

  [1]: Diederik P. Kingma, Jimmy Ba. Adam: A Method for Stochastic Optimization.
        _arXiv preprint arXiv:1412.6980_, 2014.
       https://arxiv.org/abs/1412.6980
  """
  with tf.name_scope(name or 'moving_mean_variance_zero_debiased'):
    if zero_debias_count is None:
      raise ValueError('Argument `zero_debias_count` must be specified.')
    base_dtype = dtype_util.base_dtype(moving_mean.dtype)
    if not dtype_util.is_floating(base_dtype):
      raise TypeError(
          'Argument `moving_mean` is not float type (saw {}).'.format(
              dtype_util.name(moving_mean.dtype)))
    t = tf.cast(zero_debias_count, dtype=base_dtype)
    # Could have used:
    #   bias_correction = -tf.math.expm1(t * tf.math.log(decay))
    # however since we expect decay to be nearly 1, we don't expect this to bear
    # a significant improvement, yet would incur higher computational cost.
    t = tf.where(t > 0., t, tf.constant(np.inf, base_dtype))
    bias_correction = 1. - decay**t
    unbiased_mean = moving_mean / bias_correction
    if moving_variance is None:
      return unbiased_mean
    if base_dtype != dtype_util.base_dtype(moving_variance.dtype):
      raise TypeError('Arguments `moving_mean` and `moving_variance` do not '
                      'have same base `dtype` (saw {}, {}).'.format(
                          dtype_util.name(moving_mean.dtype),
                          dtype_util.name(moving_variance.dtype)))
    unbiased_variance = moving_variance / bias_correction
    return unbiased_mean, unbiased_variance
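An end-to-end sketch under the assumption that both this function and
`assign_moving_mean_variance` (shown later in this collection) are exported
under `tfp.stats`: stream a constant value through zero-initialized
variables, then recover the exact statistics by zero debiasing:

```python
import tensorflow as tf
import tensorflow_probability as tfp

moving_mean = tf.Variable(0.)
moving_variance = tf.Variable(0.)
zero_debias_count = tf.Variable(0)

for _ in range(10):
  tfp.stats.assign_moving_mean_variance(
      value=2., moving_mean=moving_mean, moving_variance=moving_variance,
      zero_debias_count=zero_debias_count, decay=0.9)

# The variable holds the biased estimate 2 * (1 - 0.9**10); dividing by the
# bias correction 1 - 0.9**10 recovers the exact mean of the constant stream.
m, v = tfp.stats.moving_mean_variance_zero_debiased(
    moving_mean, moving_variance, zero_debias_count, decay=0.9)
print(moving_mean.numpy(), m.numpy())  # ~1.303 (biased) vs. 2.0 (debiased)
```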
Exemplo n.º 25
0
def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
    """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  only if the input has nonreference semantics. Reference semantics are
  identified by `tensor_util.is_ref`, which flags any object that is a
  `tf.Variable` or an instance of `tf.Module`. This function accepts any input
  which `tf.convert_to_tensor` would also.

  Note: This function diverges from default Numpy behavior for `float` and
    `string` types when `None` is present in a Python list or scalar. Rather
    than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference.  If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.


  #### Examples:

  ```python
  from tensorflow_probability.python.internal import tensor_util

  x = tf.Variable(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = tensor_util.convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> True
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
    # We explicitly do not use a tf.name_scope to avoid graph clutter.
    if value is None:
        return None
    if is_ref(value):
        if dtype is None:
            return value
        dtype_base = dtype_util.base_dtype(dtype)
        value_dtype_base = dtype_util.base_dtype(value.dtype)
        if dtype_base != value_dtype_base:
            raise TypeError(
                'Mutable type must be of dtype "{}" but is "{}".'.format(
                    dtype_util.name(dtype_base),
                    dtype_util.name(value_dtype_base)))
        return value
    return tf.convert_to_tensor(value,
                                dtype=dtype,
                                dtype_hint=dtype_hint,
                                name=name)
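A complementary illustration of why the ref/non-ref split matters: a
`tf.Variable` passed through this function is returned unconverted and is
therefore still read at use time, whereas an eager `tf.convert_to_tensor`
freezes a snapshot of its current value:

```python
import tensorflow as tf
from tensorflow_probability.python.internal import tensor_util

v = tf.Variable(1.)
kept = tensor_util.convert_nonref_to_tensor(v)  # ref type: returned as-is
frozen = tf.convert_to_tensor(v)                # value snapshot, taken now

v.assign(2.)
print(tf.convert_to_tensor(kept).numpy())  # ==> 2.0 (reads the variable)
print(frozen.numpy())                      # ==> 1.0 (captured pre-assign)
```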
Exemplo n.º 26
0
    def __init__(self,
                 loc,
                 scale,
                 quadrature_size=8,
                 quadrature_fn=quadrature_scheme_lognormal_quantiles,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="PoissonLogNormalQuadratureCompound"):
        """Constructs the PoissonLogNormalQuadratureCompound`.

    Note: `probs` returned by (optional) `quadrature_fn` are presumed to be
    either a length-`quadrature_size` vector or a batch of vectors in 1-to-1
    correspondence with the returned `grid`. (I.e., broadcasting is only
    partially supported.)

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
        the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_size: Python `int` scalar representing the number of quadrature
        points.
      quadrature_fn: Python callable taking `loc`, `scale`,
        `quadrature_size`, `validate_args` and returning `tuple(grid, probs)`
        representing the LogNormal grid and corresponding normalized weights.
        Default value: `quadrature_scheme_lognormal_quantiles`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `quadrature_grid` and `quadrature_probs` have different base
        `dtype`.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], tf.float32)
            if loc is not None:
                loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)
            if scale is not None:
                scale = tf.convert_to_tensor(scale, dtype=dtype, name="scale")
            self._quadrature_grid, self._quadrature_probs = tuple(
                quadrature_fn(loc, scale, quadrature_size, validate_args))

            dt = self._quadrature_grid.dtype
            if not dtype_util.base_equal(dt, self._quadrature_probs.dtype):
                raise TypeError(
                    "Quadrature grid dtype ({}) does not match quadrature "
                    "probs dtype ({}).".format(
                        dtype_util.name(dt),
                        dtype_util.name(self._quadrature_probs.dtype)))

            self._distribution = poisson.Poisson(
                log_rate=self._quadrature_grid,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            self._mixture_distribution = categorical.Categorical(
                logits=tf.math.log(self._quadrature_probs),
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats)

            self._loc = loc
            self._scale = scale
            self._quadrature_size = quadrature_size

            super(PoissonLogNormalQuadratureCompound, self).__init__(
                dtype=dt,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[loc, scale],
                name=name)
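A minimal usage sketch, assuming the class is exported as
`tfp.distributions.PoissonLogNormalQuadratureCompound`; the parameter values
are illustrative only:

```python
import tensorflow_probability as tfp

# Poisson rate marginalized over a LogNormal(0, 1) prior, approximated by an
# 8-point quantile quadrature grid.
dist = tfp.distributions.PoissonLogNormalQuadratureCompound(
    loc=0., scale=1., quadrature_size=8)

print(dist.mean())           # ~= E[LogNormal(0, 1)] = exp(0.5) ~= 1.649
print(dist.sample(3, seed=42))
```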
Exemplo n.º 27
0
def _str(s):
    if s is None:
        return '?'
    return dtype_util.name(s)
Exemplo n.º 28
0
def sample_lkj(num_samples,
               dimension,
               concentration,
               cholesky_space=False,
               seed=None,
               name=None):
    """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    dimension: Python `int`. The dimension of correlation matrices.
    concentration: `Tensor` representing the concentration of the LKJ
      distribution.
    cholesky_space: Python `bool`. Whether to take samples from LKJ or
      Chol(LKJ).
    seed: Python integer seed for the random number generator.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    samples: A Tensor of correlation matrices (or Cholesky factors of
      correlation matrices if `cholesky_space = True`) with shape
      `[n] + B + [D, D]`, where `B` is the shape of the `concentration`
      parameter, and `D` is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
  """
    if dimension < 0:
        raise ValueError(
            'Cannot sample negative-dimension correlation matrices.')
    # Notation below: B is the batch shape, i.e., tf.shape(concentration)

    # We need 1 seed for beta corr12, and 2 per loop iter.
    num_seeds = 1 + 2 * max(0, dimension - 2)
    seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))
    with tf.name_scope(name or 'sample_lkj'):
        concentration = tf.convert_to_tensor(concentration)
        if not dtype_util.is_floating(concentration.dtype):
            raise TypeError(
                'The concentration argument should have floating type, not '
                '{}'.format(dtype_util.name(concentration.dtype)))

        concentration = _replicate(num_samples, concentration)
        concentration_shape = tf.shape(concentration)
        if dimension <= 1:
            # For any dimension <= 1, there is only one possible correlation matrix.
            shape = tf.concat([concentration_shape, [dimension, dimension]],
                              axis=0)
            return tf.ones(shape=shape, dtype=concentration.dtype)
        beta_conc = concentration + (dimension - 2.) / 2.
        beta_dist = beta.Beta(concentration1=beta_conc,
                              concentration0=beta_conc)

        # Note that the sampler below deviates from [1], by doing the sampling in
        # cholesky space. This does not change the fundamental logic of the
        # sampler, but does speed up the sampling.

        # This is the correlation coefficient between the first two dimensions.
        # This is also `r` in reference [1].
        corr12 = 2. * beta_dist.sample(seed=seeds.pop()) - 1.

        # Below we construct the Cholesky of the initial 2x2 correlation matrix,
        # which is of the form:
        # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
        # first two dimensions.
        # This is the top-left corner of the cholesky of the final sample.
        first_row = tf.concat([
            tf.ones_like(corr12)[..., tf.newaxis],
            tf.zeros_like(corr12)[..., tf.newaxis]
        ],
                              axis=-1)
        second_row = tf.concat(
            [corr12[..., tf.newaxis],
             tf.sqrt(1 - corr12**2)[..., tf.newaxis]],
            axis=-1)

        chol_result = tf.concat(
            [first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]],
            axis=-2)

        for n in range(2, dimension):
            # Loop invariant: on entry, result has shape B + [n, n]
            beta_conc = beta_conc - 0.5
            # norm is y in reference [1].
            norm = beta.Beta(concentration1=n / 2.,
                             concentration0=beta_conc).sample(seed=seeds.pop())
            # distance shape: B + [1] for broadcast
            distance = tf.sqrt(norm)[..., tf.newaxis]
            # direction is u in reference [1].
            # direction shape: B + [n]
            direction = _uniform_unit_norm(n,
                                           concentration_shape,
                                           concentration.dtype,
                                           seed=seeds.pop())
            # raw_correlation is w in reference [1].
            raw_correlation = distance * direction  # shape: B + [n]

            # This is the next row in the cholesky of the result,
            # which differs from the construction in reference [1].
            # In the reference, the new row `z` = chol_result @ raw_correlation^T
            # = C @ raw_correlation^T (where as short hand we use C = chol_result).
            # We prove that the below equation is the right row to add to the
            # cholesky, by showing equality with reference [1].
            # Let S be the sample constructed so far, and let `z` be as in
            # reference [1]. Then at this iteration, the new sample S' will be
            # [[S z^T]
            #  [z 1]]
            # In our case we have the cholesky decomposition factor C, so
            # we want our new row x (same size as z) to satisfy:
            #   [[S  z^T]     [[C  0]     [[C^T  x^T]     [[C C^T   C x^T      ]
            #    [z  1  ]]  =  [x  k]]  @  [0    k  ]]  =  [x C^T   x x^T + k^2]]
            # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
            # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
            # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
            # distance**2).
            new_row = tf.concat(
                [raw_correlation,
                 tf.sqrt(1. - norm[..., tf.newaxis])],
                axis=-1)

            # Finally add this new row, by growing the cholesky of the result.
            chol_result = tf.concat([
                chol_result,
                tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
            ],
                                    axis=-1)

            chol_result = tf.concat([chol_result, new_row[..., tf.newaxis, :]],
                                    axis=-2)

        assert not seeds, 'Did not use all seeds: ' + str(len(seeds))
        if cholesky_space:
            return chol_result

        result = tf.matmul(chol_result, chol_result, transpose_b=True)
        # The diagonal for a correlation matrix should always be ones. Due to
        # numerical instability the matmul might not achieve that, so manually set
        # these to ones.
        result = tf.linalg.set_diag(
            result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
        # This sampling algorithm can produce near-PSD matrices on which standard
        # algorithms such as `tf.linalg.cholesky` or `tf.linalg.self_adjoint_eigvals`
        # fail. Specifically, as documented in b/116828694, around 2% of trials
        # of 900,000 5x5 matrices (distributed according to 9 different
        # concentration parameter values) contained at least one matrix on which
        # the Cholesky decomposition failed.
        return result
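A hedged sketch exercising the sampler through its public wrapper, on the
assumption that `tfp.distributions.LKJ` routes sampling through a function
like the one above; it checks the documented shape and unit diagonal:

```python
import tensorflow as tf
import tensorflow_probability as tfp

lkj = tfp.distributions.LKJ(dimension=3, concentration=2., validate_args=True)
corr = lkj.sample(5, seed=42)

print(corr.shape)                 # ==> (5, 3, 3)
print(tf.linalg.diag_part(corr))  # unit diagonals, set explicitly above
```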
Exemplo n.º 29
0
def assign_moving_mean_variance(value, moving_mean, moving_variance=None,
                                zero_debias_count=None, decay=0.99, axis=(),
                                name=None):
  """Compute one update to the exponentially weighted moving mean and variance.

  The `value` updated exponentially weighted moving `moving_mean` and
  `moving_variance` are conceptually given by the following recurrence
  relations ([Welford (1962)][1]):

  ```python
  new_mean = old_mean + (1 - decay) * (value - old_mean)
  new_var  = old_var  + (1 - decay) * (value - old_mean) * (value - new_mean)
  ```

  This function implements the above recurrences in a numerically stable manner
  and also uses the `assign_add` op to allow concurrent lockless updates to the
  supplied variables.

  For additional references see [this John D. Cook blog post](
  https://www.johndcook.com/blog/standard_deviation/)
  (whereas we use `1 - decay = 1 / k`) and
  [Finch (2009; Eq. 143)][2] (whereas we use `1 - decay = alpha`).

  Since variables that are initialized to a `0` value will be `0` biased,
  providing `zero_debias_count` triggers scaling the `moving_mean` and
  `moving_variance` by the factor of `1 - decay ** (zero_debias_count + 1)`.
  For more details, see `tfp.stats.moving_mean_variance_zero_debiased`.

  Args:
    value: `float`-like `Tensor` representing one or more streaming
      observations. When `axis` is non-empty, `value` is reduced (by mean) to
      update `moving_mean` and `moving_variance`. Presumed to have same shape
      as `moving_mean` and `moving_variance`.
    moving_mean: `float`-like `tf.Variable` representing the exponentially
      weighted moving mean. Same shape as `moving_variance` and `value`. This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
    moving_variance: `float`-like `tf.Variable` representing the exponentially
      weighted moving variance. Same shape as `moving_mean` and `value`.  This
      function presumes the `tf.Variable` was created with all zero initial
      value(s).
      Default value: `None` (i.e., no moving variance is computed).
    zero_debias_count: `int`-like `tf.Variable` representing the number of times
      this function has been called on streaming input (*not* the number of
      reduced values used in this function's computation). When not `None`, the
      returned values for `moving_mean` and `moving_variance` are
      "zero debiased", i.e., corrected for their presumed all-zeros
      initialization. Note: the `tf.Variable`s `moving_mean` and
      `moving_variance` *always* store the biased calculation, regardless of
      setting this argument. To obtain debiased calculations from these
      `tf.Variable`s, see `tfp.stats.moving_mean_variance_zero_debiased`.
      Default value: `None` (i.e., no zero debiasing calculation is made).
    decay: A `float`-like `Tensor` representing the moving mean decay. Typically
      close to `1.`, e.g., `0.99`.
      Default value: `0.99`.
    axis: The dimensions to reduce. If `()` (the default) no dimensions are
      reduced. If `None` all dimensions are reduced. Must be in the range
      `[-rank(value), rank(value))`.
      Default value: `()` (i.e., no reduction is made).
    name: Python `str` prepended to op names created by this function.
      Default value: `None` (i.e., 'assign_moving_mean_variance').

  Returns:
    moving_mean: The `value`-updated exponentially weighted moving mean.
      Debiased if `zero_debias_count is not None`.
    moving_variance: The `value`-updated exponentially weighted moving variance.
      Debiased if `zero_debias_count is not None`.

  Raises:
    TypeError: if `moving_mean` does not have float type `dtype`.
    TypeError: if `moving_mean`, `moving_variance`, `value`, `decay` have
      different `base_dtype`.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  d = tfd.MultivariateNormalTriL(
      loc=[-1., 1.],
      scale_tril=tf.linalg.cholesky([[0.75, 0.05],
                                     [0.05, 0.5]]))
  d.mean()
  # ==> [-1.,  1.]
  d.variance()
  # ==> [0.75, 0.5]
  moving_mean = tf.Variable(tf.zeros(2))
  moving_variance = tf.Variable(tf.zeros(2))
  zero_debias_count = tf.Variable(0)
  for _ in range(100):
    m, v = tfp.stats.assign_moving_mean_variance(
      value=d.sample(3),
      moving_mean=moving_mean,
      moving_variance=moving_variance,
      zero_debias_count=zero_debias_count,
      decay=0.99,
      axis=-2)
    print(m.numpy(), v.numpy())
  # ==> [-1.0334632  0.9545268] [0.8126194 0.5118788]
  # ==> [-1.0293456   0.96070296] [0.8115873  0.50947404]
  # ...
  # ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]

  m1, v1 = tfp.stats.moving_mean_variance_zero_debiased(
    moving_mean,
    moving_variance,
    zero_debias_count,
    decay=0.99)
  print(m1.numpy(), v1.numpy())
  # ==> [-1.025172  0.96351 ] [0.7142789  0.48570773]
  assert(all(m == m1))
  assert(all(v == v1))
  ```

  #### References

  [1]  B. P. Welford. Note on a Method for Calculating Corrected Sums of
       Squares and Products. Technometrics, Vol. 4, No. 3 (Aug., 1962), p419-20.
       http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.302.7503&rep=rep1&type=pdf
       http://www.jstor.org/stable/1266577

  [2]: Tony Finch. Incremental calculation of weighted mean and variance.
       _Technical Report_, 2009.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  """
  with tf.name_scope(name or 'assign_moving_mean_variance'):
    base_dtype = dtype_util.base_dtype(moving_mean.dtype)
    if not dtype_util.is_floating(base_dtype):
      raise TypeError(
          'Argument `moving_mean` is not float type (saw {}).'.format(
              dtype_util.name(moving_mean.dtype)))

    value = tf.convert_to_tensor(value, dtype=base_dtype, name='value')
    decay = tf.convert_to_tensor(decay, dtype=base_dtype, name='decay')
    # Force a read of `moving_mean` as we'll need it twice.
    old_mean = tf.convert_to_tensor(
        moving_mean, dtype=base_dtype, name='old_mean')

    updated_mean = moving_mean.assign_add(
        (1. - decay) * (tf.reduce_mean(value, axis=axis) - old_mean))

    if zero_debias_count is not None:
      t = tf.cast(zero_debias_count.assign_add(1), base_dtype)
      # Could have used:
      #   bias_correction = -tf.math.expm1(t * tf.math.log(decay))
      # however since we expect decay to be nearly 1, we don't expect this to
      # bear a significant improvement, yet would incur higher computational
      # cost.
      bias_correction = 1. - decay**t
      with tf.control_dependencies([updated_mean]):
        updated_mean = updated_mean / bias_correction

    if moving_variance is None:
      return updated_mean

    if base_dtype != dtype_util.base_dtype(moving_variance.dtype):
      raise TypeError('Arguments `moving_mean` and `moving_variance` do not '
                      'have same base `dtype` (saw {}, {}).'.format(
                          dtype_util.name(moving_mean.dtype),
                          dtype_util.name(moving_variance.dtype)))

    if zero_debias_count is not None:
      old_t = tf.where(t > 1., t - 1., tf.constant(np.inf, base_dtype))
      old_bias_correction = 1. - decay**old_t
      old_mean = old_mean / old_bias_correction

    mean_sq_diff = tf.reduce_mean(
        tf.math.squared_difference(value, old_mean),
        axis=axis)
    updated_variance = moving_variance.assign_add(
        (1. - decay) * (decay * mean_sq_diff - moving_variance))

    if zero_debias_count is not None:
      with tf.control_dependencies([updated_variance]):
        updated_variance = updated_variance / bias_correction

    return updated_mean, updated_variance
Exemplo n.º 30
0
  def __init__(self,
               rate=None,
               log_rate=None,
               force_probs_to_zero_outside_support=None,
               interpolate_nondiscrete=True,
               validate_args=False,
               allow_nan_stats=True,
               name='Poisson'):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be positive.
        Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter.
        Must specify exactly one of `rate` and `log_rate`.
      force_probs_to_zero_outside_support: Python `bool`. When `True`, negative
        and non-integer values are evaluated "strictly": `log_prob` returns
        `-inf`, `prob` returns `0`, and `cdf` and `sf` correspond.  When
        `False`, the implementation is free to save computation (and TF graph
        size) by evaluating something that matches the Poisson pmf at integer
        values `k` but produces an unrestricted result on other inputs.  In the
        case of Poisson, the `log_prob` formula in this case happens to be the
        continuous function `k * log_rate - lgamma(k+1) - rate`.  Note that this
        function is not itself a normalized probability log-density.
        Default value: `False`.
      interpolate_nondiscrete: Deprecated.  Use
        `force_probs_to_zero_outside_support` (with the opposite sense) instead.
        Python `bool`. When `False`, `log_prob` returns `-inf` (and `prob`
        returns `0`) for non-integer inputs. When `True`, `log_prob` evaluates
        the continuous function `k * log_rate - lgamma(k+1) - rate`, which
        matches the Poisson pmf at integer arguments `k` (note that this
        function is not itself a normalized probability log-density).
        Default value: `True`.
      validate_args: Python `bool`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    parameters = dict(locals())
    if (rate is None) == (log_rate is None):
      raise ValueError('Must specify exactly one of `rate` and `log_rate`.')
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype([rate, log_rate], dtype_hint=tf.float32)
      if not dtype_util.is_floating(dtype):
        raise TypeError('[log_]rate.dtype ({}) is not a float-type.'.format(
            dtype_util.name(dtype)))
      self._rate = tensor_util.convert_nonref_to_tensor(
          rate, name='rate', dtype=dtype)
      self._log_rate = tensor_util.convert_nonref_to_tensor(
          log_rate, name='log_rate', dtype=dtype)

      self._interpolate_nondiscrete = interpolate_nondiscrete
      if force_probs_to_zero_outside_support is not None:
        # `force_probs_to_zero_outside_support` was explicitly set, so it
        # controls.
        self._force_probs_to_zero_outside_support = (
            force_probs_to_zero_outside_support)
      elif not self._interpolate_nondiscrete:
        # `interpolate_nondiscrete` was explicitly set by the caller, so it
        # should control until it is removed.
        self._force_probs_to_zero_outside_support = True
      else:
        # Default.
        self._force_probs_to_zero_outside_support = False
      super(Poisson, self).__init__(
          dtype=dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          name=name)
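A brief usage sketch for the constructor above, assuming it is exposed as
`tfp.distributions.Poisson`; it demonstrates the `rate`/`log_rate`
exclusivity and the effect of `force_probs_to_zero_outside_support`:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Exactly one of `rate` / `log_rate` may be specified.
interp = tfp.distributions.Poisson(log_rate=tf.math.log(3.))
strict = tfp.distributions.Poisson(
    rate=3., force_probs_to_zero_outside_support=True)

print(interp.prob(2.5))  # nonzero: the continuous interpolant is evaluated
print(strict.prob(2.5))  # ==> 0.0: non-integer inputs are strictly rejected

# Specifying both (or neither) raises a ValueError:
# tfp.distributions.Poisson(rate=3., log_rate=1.)
```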