Example #1
def _calculate_batch_shape(self):
  """Computes the fully defined batch shape for the new distribution."""
  # Use static shapes where fully known; fall back to dynamic shape
  # tensors otherwise.
  all_batch_shapes = [d.batch_shape.as_list()
                      if tensorshape_util.is_fully_defined(d.batch_shape)
                      else d.batch_shape_tensor() for d in self.distributions]
  # Stack into a [num_distributions, batch_rank] matrix of dimension sizes.
  original_shape = ps.stack(all_batch_shapes, axis=0)
  # Boolean mask selecting the concatenation axis.
  index_mask = ps.cast(
      ps.one_hot(self._axis, ps.shape(original_shape)[1]),
      dtype=tf.bool)
  # Along the concatenation axis, the new size is the sum of the
  # component sizes.
  new_concat_dim = ps.cast(
      ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
  # Along every other axis, take the max across components (the
  # broadcast size).
  return ps.where(index_mask, new_concat_dim,
                  ps.reduce_max(original_shape, axis=0))
Example #2
def initial_value_of_masked_time_series(time_series_tensor, broadcast_mask):
  """Gets the first unmasked entry of each time series in the batch.

  If a batch element has no unmasked entries, the corresponding return value
  for that element is undefined.

  Args:
    time_series_tensor: float `Tensor` of shape `batch_shape + [num_timesteps]`.
    broadcast_mask: bool `Tensor` of the same shape as `time_series_tensor`.
  Returns:
    initial_values: float `Tensor` of shape `batch_shape`.
  """
  num_timesteps = ps.shape(time_series_tensor)[-1]

  # Compute the index of the first unmasked entry for each series in the batch.
  unmasked_negindices = (ps.cast(~broadcast_mask, np.int32) *
                         ps.range(num_timesteps, 0, -1))
  first_unmasked_indices = num_timesteps - ps.reduce_max(unmasked_negindices,
                                                         axis=-1)
  # Avoid out-of-bounds errors if all entries are masked.
  safe_unmasked_indices = ps.minimum(first_unmasked_indices,
                                     num_timesteps - 1)

  batch_dims = tensorshape_util.rank(safe_unmasked_indices.shape)
  if batch_dims is None:
    raise NotImplementedError(
        'Cannot compute initial values of a masked time series with '
        'dynamic rank.')  # `batch_gather` requires static rank.

  # Extract the initial value for each series in the batch.
  return tf.squeeze(
      tf.gather(params=time_series_tensor,
                indices=safe_unmasked_indices[..., np.newaxis],
                batch_dims=batch_dims,
                axis=-1),
      # Since we've gathered exactly one step from the
      # `num_timesteps` axis, we can remove that axis entirely.
      axis=-1)
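
A quick usage sketch (the values are made up; NaNs mark the missing steps only for this example):

import numpy as np
import tensorflow as tf

time_series = tf.constant([[np.nan, np.nan, 5., 6.],
                           [1., 2., 3., 4.]])
mask = tf.math.is_nan(time_series)  # True marks masked entries.
initial_values = initial_value_of_masked_time_series(time_series, mask)
# ==> [5., 1.], the first unmasked entry of each series.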
Example #3
def _update_forward_min_event_ndims(
    bij,
    downstream_quantities,
    get_forward_min_event_ndims=lambda b: b.forward_min_event_ndims,
    get_inverse_min_event_ndims=lambda b: b.inverse_min_event_ndims,
    inverse_event_ndims_fn=lambda b, nd: b.inverse_event_ndims(nd)):
  """Step backwards through the graph to infer `forward_min_event_ndims`.

  Args:
    bij: local tfb.Bijector instance at the current graph node.
    downstream_quantities: Instance of `MinEventNdimsDownstreamQuantities`
      namedtuple, containing event_ndims that satisfy the bijector(s)
      downstream from `bij` in the graph. May be `None` if there are no such
      bijectors.
    get_forward_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    get_inverse_min_event_ndims: callable; may be overridden to swap
      forward/inverse direction.
    inverse_event_ndims_fn: callable; may be overridden to swap
      forward/inverse direction.
  Returns:
    downstream_quantities: Instance of `MinEventNdimsDownstreamQuantities`
      namedtuple containing event_ndims that satisfy `bij` and all downstream
      bijectors.
  """
  if downstream_quantities is None:  # This is a leaf bijector.
    return MinEventNdimsInferenceDownstreamQuantities(
        forward_min_event_ndims=get_forward_min_event_ndims(bij),
        parts_interact=bij._parts_interact)  # pylint: disable=protected-access

  inverse_min_event_ndims = get_inverse_min_event_ndims(bij)
  downstream_min_event_ndims = nest_util.coerce_structure(
      inverse_min_event_ndims,
      downstream_quantities.forward_min_event_ndims)

  # Update the min_event_ndims that is a valid input to downstream bijectors
  # to also be a valid *output* of this bijector, or equivalently, a valid
  # input to `bij.inverse`.
  rank_mismatches = tf.nest.flatten(
      tf.nest.map_structure(
          lambda dim, min_dim: dim - min_dim,
          downstream_min_event_ndims,
          inverse_min_event_ndims))
  if downstream_quantities.parts_interact:
    # If downstream bijectors involve interaction between parts, then a
    # valid input to the downstream bijectors must augment
    # `downstream_min_event_ndims` by the same rank for every part
    # (otherwise we would induce event-shape broadcasting). Ideally this
    # also avoids event-shape broadcasting at the current bijector; if not,
    # the composition is invalid, and the call to
    # `bij.inverse_event_ndims(valid_inverse_min_event_ndims)` below will
    # raise an exception.
    maximum_rank_deficiency = -ps.reduce_min([0] + rank_mismatches)
    valid_inverse_min_event_ndims = tf.nest.map_structure(
        lambda ndims: maximum_rank_deficiency + ndims,
        downstream_min_event_ndims)
  else:
    if bij._parts_interact:  # pylint: disable=protected-access
      # If this bijector does *not* operate independently on its parts, then a
      # valid input to `inverse` cannot require event shape broadcasting. That
      # is, each part must have the same 'excess rank' above the local
      # inverse_min_event_ndims; we ensure this by construction.
      maximum_excess_rank = ps.reduce_max([0] + rank_mismatches)
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          lambda ndims: maximum_excess_rank + ndims,
          inverse_min_event_ndims)
    else:
      # If all parts are independent, we can take the pointwise max
      # event_ndims.
      valid_inverse_min_event_ndims = tf.nest.map_structure(
          ps.maximum, downstream_min_event_ndims, inverse_min_event_ndims)

  return MinEventNdimsInferenceDownstreamQuantities(
      # Pull the desired output ndims back through the bijector, to get
      # the ndims of a valid *input*.
      forward_min_event_ndims=inverse_event_ndims_fn(
          bij, valid_inverse_min_event_ndims),
      parts_interact=(
          downstream_quantities.parts_interact or
          bij._parts_interact))  # pylint: disable=protected-access
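
For reference, the quantities being propagated here are the bijectors' static `forward_min_event_ndims` / `inverse_min_event_ndims`. A small sketch using the public TFP API (the particular bijectors are chosen purely for illustration):

import tensorflow_probability as tfp
tfb = tfp.bijectors

# Elementwise bijectors accept scalar events.
assert tfb.Exp().forward_min_event_ndims == 0
# SoftmaxCentered operates on vectors, so it needs at least rank-1 events.
assert tfb.SoftmaxCentered().forward_min_event_ndims == 1
# Composition propagates these requirements through the bijector graph;
# the function above performs one such propagation step per node.
chain = tfb.Chain([tfb.Exp(), tfb.SoftmaxCentered()])
assert chain.forward_min_event_ndims == 1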
Example #4
  def __init__(self,
               skewness,
               tailweight,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Construct Johnson's SU distributions.

    The distributions have shape parameters `tailweight` and `skewness`,
    location `loc`, and scale `scale`.

    The parameters `tailweight`, `skewness`, `loc`, and `scale` must be shaped
    in a way that supports broadcasting
    (e.g. `skewness + tailweight + loc + scale` is a valid operation).

    Args:
      skewness: Floating-point `Tensor`. Skewness of the distribution(s).
      tailweight: Floating-point `Tensor`. Tail weight of the
        distribution(s). `tailweight` must contain only positive values.
      loc: Floating-point `Tensor`. The location(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True`, distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False`, invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if any of `skewness`, `tailweight`, `loc`, and `scale` have
        different dtypes.
    """
    parameters = dict(locals())
    with tf.name_scope(name or 'JohnsonSU') as name:
      dtype = dtype_util.common_dtype([skewness, tailweight, loc, scale],
                                      tf.float32)
      self._skewness = tensor_util.convert_nonref_to_tensor(
          skewness, name='skewness', dtype=dtype)
      self._tailweight = tensor_util.convert_nonref_to_tensor(
          tailweight, name='tailweight', dtype=dtype)
      self._loc = tensor_util.convert_nonref_to_tensor(
          loc, name='loc', dtype=dtype)
      self._scale = tensor_util.convert_nonref_to_tensor(
          scale, name='scale', dtype=dtype)

      # Y = loc + scale * sinh((X - skewness) / tailweight), with X standard
      # normal, expressed as a chain of elementary bijectors.
      norm_shift = invert_bijector.Invert(
          shift_bijector.Shift(shift=self._skewness,
                               validate_args=validate_args))

      norm_scale = invert_bijector.Invert(
          scale_bijector.Scale(scale=self._tailweight,
                               validate_args=validate_args))

      sinh = sinh_bijector.Sinh(validate_args=validate_args)

      scale = scale_bijector.Scale(scale=self._scale,
                                   validate_args=validate_args)

      shift = shift_bijector.Shift(shift=self._loc,
                                   validate_args=validate_args)

      bijector = shift(scale(sinh(norm_scale(norm_shift))))

      batch_rank = ps.reduce_max([
          distribution_util.prefer_static_rank(x)
          for x in (self._skewness, self._tailweight, self._loc, self._scale)
      ])

      super(JohnsonSU, self).__init__(
          # TODO(b/160730249): Make `loc` a scalar `0.` and remove overridden
          # `batch_shape` and `batch_shape_tensor` when
          # TransformedDistribution's bijector can modify its `batch_shape`.
          distribution=normal.Normal(
              loc=tf.zeros(ps.ones(batch_rank, tf.int32), dtype=dtype),
              scale=tf.ones([], dtype=dtype),
              validate_args=validate_args,
              allow_nan_stats=allow_nan_stats),
          bijector=bijector,
          validate_args=validate_args,
          parameters=parameters,
          name=name)
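
A brief usage sketch with the public API (the parameter values are arbitrary):

import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.JohnsonSU(skewness=-2., tailweight=2., loc=1.1, scale=1.5)
samples = dist.sample(3)            # shape: [3]
log_probs = dist.log_prob(samples)  # shape: [3]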