def _fn(x):
     """MADE parameterized via `masked_autoregressive_default_template`."""
     # TODO(b/67594795): Better support of dynamic shape.
     input_depth = tf.compat.dimension_value(
         tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
     if input_depth is None:
         raise NotImplementedError(
             'Rightmost dimension must be known prior to graph execution.'
         )
     input_shape = (np.int32(tensorshape_util.as_list(x.shape))
                    if tensorshape_util.is_fully_defined(x.shape) else
                    tf.shape(x))
     if tensorshape_util.rank(x.shape) == 1:
         x = x[tf.newaxis, ...]
     for i, units in enumerate(hidden_layers):
         x = masked_dense(
             inputs=x,
             units=units,
             num_blocks=input_depth,
              exclusive=(i == 0),
             activation=activation,
             *args,  # pylint: disable=keyword-arg-before-vararg
             **kwargs)
     x = masked_dense(
         inputs=x,
         units=(1 if shift_only else 2) * input_depth,
         num_blocks=input_depth,
         activation=None,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs)
     if shift_only:
         x = tf.reshape(x, shape=input_shape)
         return x, None
     x = tf.reshape(x, shape=tf.concat([input_shape, [2]], axis=0))
     shift, log_scale = tf.unstack(x, num=2, axis=-1)
     which_clip = (tf.clip_by_value if log_scale_clip_gradient else
                   clip_by_value_preserve_gradient)
     log_scale = which_clip(log_scale, log_scale_min_clip,
                            log_scale_max_clip)
     return shift, log_scale
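
The closure above only builds the shift/log-scale network; it becomes useful once wired into a `MaskedAutoregressiveFlow` bijector. A minimal usage sketch, assuming the public TFP API (`tfb.AutoregressiveNetwork` is the maintained Keras-based counterpart of the template-style MADE above; hidden-layer sizes are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# MADE network emitting two parameters (shift, log_scale) per dimension.
made = tfb.AutoregressiveNetwork(params=2, hidden_units=[32, 32])

# Transform a standard normal base distribution with the autoregressive flow.
maf = tfd.TransformedDistribution(
    distribution=tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[3]),
    bijector=tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=made))

x = maf.sample(5)        # shape [5, 3]
lp = maf.log_prob(x)     # shape [5]
```
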
Example 2
  def _parameter_control_dependencies(self, is_init):
    if tensorshape_util.is_fully_defined(self.distribution.batch_shape):
      if self.to_shape is not None:
        static_to_shape = tf.get_static_value(self.to_shape)
        if static_to_shape is not None:
          bcast_shp = tf.broadcast_static_shape(
              tf.TensorShape(static_to_shape),
              self.distribution.batch_shape)
          if bcast_shp != static_to_shape:
            raise ValueError(f'Argument `to_shape` ({static_to_shape}) '
                             'is incompatible with underlying distribution '
                             f'batch shape ({self.distribution.batch_shape}).')

      else:
        static_with_shape = tf.get_static_value(self.with_shape)
        if static_with_shape is not None:
          tf.broadcast_static_shape(  # Ensure compatible.
              tf.TensorShape(static_with_shape),
              self.distribution.batch_shape)

    underlying = self.distribution._parameter_control_dependencies(is_init)  # pylint: disable=protected-access
    if not self.validate_args:
      return underlying

    checks = []
    if self.to_shape is not None:
      if tensor_util.is_ref(self.to_shape) != is_init:
        checks += [assert_util.assert_equal(
            self.to_shape,
            ps.broadcast_shape(self.distribution.batch_shape_tensor(),
                               self.to_shape),
            message='Argument `to_shape` is incompatible with underlying '
                    'distribution batch shape.')]
    else:
      if tensor_util.is_ref(self.with_shape) != is_init:
        checks += [tf.broadcast_dynamic_shape(
            self.distribution.batch_shape_tensor(),
            self.with_shape)]

    return tuple(checks) + tuple(underlying)
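
These `to_shape`/`with_shape` checks mirror the broadcasting contract of a batch-broadcasting wrapper such as `tfd.BatchBroadcast`. A hedged sketch of that contract, assuming that public class (shapes are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3]), scale=1.)        # batch shape [3]
# `with_shape` broadcasts the underlying batch shape against [2, 3].
bcast = tfd.BatchBroadcast(base, with_shape=[2, 3])
print(bcast.batch_shape)                               # [2, 3]

# An incompatible shape fails the static check performed above:
# tfd.BatchBroadcast(base, with_shape=[2, 4])  # raises ValueError
```
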
Example 3
    def _flatten_and_concat_event(self, x):
        def _reshape_part(part, event_shape):
            part = tf.cast(part, self.dtype)
            static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
            if static_rank == 1:
                return part
            new_shape = ps.concat([
                ps.shape(part)[:ps.size(ps.shape(part)) -
                               ps.size(event_shape)], [-1]
            ],
                                  axis=-1)
            return tf.reshape(part, ps.cast(new_shape, tf.int32))

        if all(
                tensorshape_util.is_fully_defined(s)
                for s in tf.nest.flatten(self._distribution.event_shape)):
            x = tf.nest.map_structure(_reshape_part, x,
                                      self._distribution.event_shape)
        else:
            x = tf.nest.map_structure(_reshape_part, x,
                                      self._distribution.event_shape_tensor())
        return tf.concat(tf.nest.flatten(x), axis=-1)
Example 4
def _broadcast_cat_event_and_params(event, params, base_dtype):
    """Broadcasts the event or distribution parameters."""
    if dtype_util.is_integer(event.dtype):
        pass
    elif dtype_util.is_floating(event.dtype):
        # When `validate_args=True` we've already ensured int/float casting
        # is closed.
        event = tf.cast(event, dtype=tf.int32)
    else:
        raise TypeError("`value` should have integer `dtype` or "
                        "`self.dtype` ({})".format(base_dtype))
    shape_known_statically = (tensorshape_util.rank(params.shape) is not None
                              and params.shape[:-1].is_fully_defined() and
                              tensorshape_util.is_fully_defined(event.shape))
    if not shape_known_statically or params.shape[:-1] != event.shape:
        params *= tf.ones_like(event[..., tf.newaxis], dtype=params.dtype)
        params_shape = tf.shape(input=params)[:-1]
        event *= tf.ones(params_shape, dtype=event.dtype)
        if tensorshape_util.rank(params.shape) is not None:
            tensorshape_util.set_shape(event, params.shape[:-1])

    return event, params
Example 5
def _slice_params_to_dict(dist, params_event_ndims, slices):
  """Computes the override dictionary of sliced parameters.

  Args:
    dist: The tfd.Distribution being batch-sliced.
    params_event_ndims: Per-event parameter ranks, a `str->int` `dict`.
    slices: Slices as received by __getitem__.

  Returns:
    overrides: `str->Tensor` `dict` of batch-sliced parameter overrides.
  """
  override_dict = {}
  for param_name, param_event_ndims in params_event_ndims.items():
    # Verify that either None or a legit value is in the parameters dict.
    if param_name not in dist.parameters:
      raise ValueError('Distribution {} is missing advertised '
                       'parameter {}'.format(dist, param_name))
    param = dist.parameters[param_name]
    if param is None:
      # some distributions have multiple possible parameterizations; this
      # param was not provided
      continue
    dtype = None
    if hasattr(dist, param_name):
      attr = getattr(dist, param_name)
      dtype = getattr(attr, 'dtype', None)
    if dtype is None:
      dtype = dist.dtype
      warnings.warn('Unable to find property getter for parameter Tensor {} '
                    'on {}, falling back to Distribution.dtype {}'.format(
                        param_name, dist, dtype))
    param = tf.convert_to_tensor(value=param, dtype=dtype)
    dist_batch_shape = dist.batch_shape
    if not tensorshape_util.is_fully_defined(dist_batch_shape):
      dist_batch_shape = dist.batch_shape_tensor()
    override_dict[param_name] = _slice_single_param(param, param_event_ndims,
                                                    slices,
                                                    dist_batch_shape)
  return override_dict
Example 6
def _validate_block_sizes(block_sizes, bijectors, validate_args):
  """Helper to validate block sizes."""
  block_sizes_shape = block_sizes.shape
  if tensorshape_util.is_fully_defined(block_sizes_shape):
    if (tensorshape_util.rank(block_sizes_shape) != 1 or
        (tensorshape_util.num_elements(block_sizes_shape) != len(bijectors))):
      raise ValueError(
          '`block_sizes` must be `None`, or a vector of the same length as '
          '`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
          'length {}'.format(block_sizes_shape, len(bijectors)))
    return block_sizes
  elif validate_args:
    message = ('`block_sizes` must be `None`, or a vector of the same length '
               'as `bijectors`.')
    with tf.control_dependencies([
        assert_util.assert_equal(
            tf.size(block_sizes), len(bijectors), message=message),
        assert_util.assert_equal(tf.rank(block_sizes), 1)
    ]):
      return tf.identity(block_sizes)
  else:
    return block_sizes
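
The validation above corresponds to how `block_sizes` is consumed by the public `tfb.Blockwise` bijector; a minimal sketch assuming that API (component bijectors and sizes are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# One bijector per block: the first column goes through Exp, the last two
# through Softplus.  `block_sizes` must have one entry per bijector.
blockwise = tfb.Blockwise([tfb.Exp(), tfb.Softplus()], block_sizes=[1, 2])
y = blockwise.forward(tf.zeros([4, 3]))   # shape [4, 3]
```
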
Example 7
    def _split_and_reshape_event(self, x):
        assertions = []
        message = 'Input must have at least one dimension.'
        if tensorshape_util.rank(x.shape) is not None:
            if tensorshape_util.rank(x.shape) == 0:
                raise ValueError(message)
        elif self.validate_args:
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=message))
        with tf.control_dependencies(assertions):
            event_tensors = self._distribution.event_shape_tensor()
            splits = [
                ps.maximum(1, ps.reduce_prod(s))
                for s in tf.nest.flatten(event_tensors)
            ]
            x = tf.nest.pack_sequence_as(event_tensors,
                                         tf.split(x, splits, axis=-1))

            def _reshape_part(part, dtype, event_shape):
                part = tf.cast(part, dtype)
                static_rank = tf.get_static_value(
                    ps.rank_from_shape(event_shape))
                if static_rank == 1:
                    return part
                new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                                      axis=-1)
                return tf.reshape(part, ps.cast(new_shape, tf.int32))

            if all(
                    tensorshape_util.is_fully_defined(s)
                    for s in tf.nest.flatten(self._distribution.event_shape)):
                x = tf.nest.map_structure(_reshape_part, x,
                                          self._distribution.dtype,
                                          self._distribution.event_shape)
            else:
                x = tf.nest.map_structure(
                    _reshape_part, x, self._distribution.dtype,
                    self._distribution.event_shape_tensor())
        return x
Example 8
  def _sample_control_dependencies(self, x):
    assertions = []
    if tensorshape_util.is_fully_defined(x.shape[-2:]):
      if not (tensorshape_util.dims(x.shape)[-2] ==
              tensorshape_util.dims(x.shape)[-1] ==
              self.dimension):
        raise ValueError(
            'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
                self.dimension, self.dimension, tensorshape_util.dims(x.shape)))
    elif self.validate_args:
      msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
          self.dimension, self.dimension, tf.shape(x))
      assertions.append(assert_util.assert_equal(
          tf.shape(x)[-2], self.dimension, message=msg))
      assertions.append(assert_util.assert_equal(
          tf.shape(x)[-1], self.dimension, message=msg))

    if self.validate_args:
      assertions.append(assert_util.assert_near(
          x, tf.linalg.band_part(x, -1, 0),
          message='Cholesky factors must be lower triangular.'))
    return assertions
Example 9
 def _has_valid_dimensions(self, x):
     if tensorshape_util.is_fully_defined(x.shape[-2:]):
         if (tensorshape_util.dims(x.shape)[-2] == tensorshape_util.dims(
                 x.shape)[-1] == self.dimension):
             return []
         else:
             raise ValueError(
                 'Input dimension mismatch: expected [..., {}, {}], got {}'.
                 format(self.dimension, self.dimension,
                        tensorshape_util.dims(x.shape)))
     elif self.validate_args:
         msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
             self.dimension, self.dimension, tf.shape(x))
         return [
             assert_util.assert_equal(tf.shape(x)[-2],
                                      self.dimension,
                                      message=msg),
             assert_util.assert_equal(tf.shape(x)[-1],
                                      self.dimension,
                                      message=msg)
         ]
     return []
Example 10
 def _validate_dimension(self, x):
   x = tf.convert_to_tensor(x, name='x')
   if tensorshape_util.is_fully_defined(x.shape[-2:]):
     if (tensorshape_util.dims(x.shape)[-2] ==
         tensorshape_util.dims(x.shape)[-1] ==
         self.dimension):
       pass
     else:
       raise ValueError(
           'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
               self.dimension, self.dimension, tensorshape_util.dims(x.shape)))
   elif self.validate_args:
     msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
         self.dimension, self.dimension, tf.shape(x))
     with tf.control_dependencies([
         assert_util.assert_equal(
             tf.shape(x)[-2], self.dimension, message=msg),
         assert_util.assert_equal(
             tf.shape(x)[-1], self.dimension, message=msg)
     ]):
       x = tf.identity(x)
   return x
Example 11
  def _sample_control_dependencies(self, x):
    assertions = []
    if tensorshape_util.is_fully_defined(x.shape[-2:]):
      if not (tensorshape_util.dims(x.shape)[-2] ==
              tensorshape_util.dims(x.shape)[-1] ==
              self.dimension):
        raise ValueError(
            'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
                self.dimension, self.dimension, tensorshape_util.dims(x.shape)))
    elif self.validate_args:
      msg = 'Input dimension mismatch: expected [..., {}, {}], got {}'.format(
          self.dimension, self.dimension, tf.shape(x))
      assertions.append(assert_util.assert_equal(
          tf.shape(x)[-2], self.dimension, message=msg))
      assertions.append(assert_util.assert_equal(
          tf.shape(x)[-1], self.dimension, message=msg))

    if self.validate_args and not self.input_output_cholesky:
      assertions.append(assert_util.assert_less_equal(
          dtype_util.as_numpy_dtype(x.dtype)(-1),
          x,
          message='Correlations must be >= -1.',
          summarize=30))
      assertions.append(assert_util.assert_less_equal(
          x,
          dtype_util.as_numpy_dtype(x.dtype)(1),
          message='Correlations must be <= 1.',
          summarize=30))
      assertions.append(assert_util.assert_near(
          tf.linalg.diag_part(x),
          dtype_util.as_numpy_dtype(x.dtype)(1),
          message='Self-correlations must be = 1.',
          summarize=30))
      assertions.append(assert_util.assert_near(
          x,
          tf.linalg.matrix_transpose(x),
          message='Correlation matrices must be symmetric.',
          summarize=30))
    return assertions
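
The assertions above match the validity conditions of a correlation-matrix distribution such as `tfd.LKJ` (square, entries in [-1, 1], unit diagonal, symmetric). A hedged usage sketch, assuming that class, showing where the checks would fire:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

lkj = tfd.LKJ(dimension=3, concentration=1.5, validate_args=True)
x = lkj.sample(2)        # two 3x3 correlation matrices
lp = lkj.log_prob(x)     # validate_args=True runs the sample checks above

# A matrix with an off-unit diagonal would trip the
# 'Self-correlations must be = 1.' assertion.
```
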
Example 12
def _sub_diag(nonmatrix):
    """Get the first sub-diagonal of a shape [N, N, ...] 'non matrix'."""
    with tf.name_scope('sub_matrix'):
        # TODO(b/143702351): Once array_ops.matrix_diag_part_v3 is ready and
        # exposed, replace the call to matrix_diag_part_v2 below with
        # tf.linalg.matrix_diag_part. We can also stop special-casing
        # matrix_dim < 2 at that point. Until then, an OpError is raised for
        # 1x1 matrices without static shape. In fact, non-static shape breaks
        # matrix_diag_part_v2, so we must raise this error message now.
        # See http://b/138403336 for the TF issue tracker.
        if not tensorshape_util.is_fully_defined(nonmatrix.shape[:2]):
            raise ValueError(
                '`inverse_temperatures` did not have statically defined shape, '
                'which breaks tracking of is_swap_{proposed,accepted}.  '
                'Please provide `inverse_temperatures` with statically known shape.'
            )

        # The sub-matrix of a 1x1 matrix is not defined (throws exception), so in
        # this special case return an empty matrix.
        # TODO(b/143702351) Remove this special case handling once
        # matrix_diag_part_v3 is ready.
        matrix_dim = ps.size0(nonmatrix)
        if matrix_dim is not None and matrix_dim < 2:
            # Shape is [..., 0], so returned tensor is empty, thus contains no
            # values...and therefore the fact that we use 'ones' doesn't matter.
            shape = ps.pad(ps.shape(nonmatrix)[2:],
                           paddings=[[0, 1]],
                           constant_values=0)
            matrix_sub_diag = tf.cast(tf.ones(shape), nonmatrix.dtype)

        else:
            # Get first sub-diagonal.  `padding_value` is not used (since matrix is
            # square), but is required for the API since this is raw gen_array_ops.
            matrix_sub_diag = tf.raw_ops.MatrixDiagPartV2(
                input=distribution_util.rotate_transpose(nonmatrix, shift=-2),
                k=ps.convert_to_shape_tensor(-1, dtype=tf.int32),
                padding_value=tf.cast(0.0, dtype=nonmatrix.dtype))

        return distribution_util.rotate_transpose(matrix_sub_diag, shift=1)
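
A NumPy sketch of the contract (not the TFP implementation): for a `[N, N, ...]` 'non matrix', `_sub_diag` returns the first sub-diagonal taken over the two leading dims, i.e. the entries `[i + 1, i, ...]`, with shape `[N - 1, ...]`:

```python
import numpy as np

nonmatrix = np.arange(3 * 3 * 5).reshape([3, 3, 5])   # shape [N=3, N=3, 5]

# First sub-diagonal over the leading [N, N] dims: entries [i + 1, i, ...].
sub_diag = np.stack(
    [nonmatrix[i + 1, i, ...] for i in range(nonmatrix.shape[0] - 1)], axis=0)
print(sub_diag.shape)   # (2, 5)
```
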
Example 13
def get_broadcast_shape(*tensors):
    """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors:  One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape:  Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
    # Try static.
    s_shape = tensors[0].shape
    for t in tensors[1:]:
        s_shape = tf.broadcast_static_shape(s_shape, t.shape)
    if tensorshape_util.is_fully_defined(s_shape):
        return tensorshape_util.as_list(s_shape)

    # Fallback on dynamic.
    d_shape = tf.shape(tensors[0])
    for t in tensors[1:]:
        d_shape = tf.broadcast_dynamic_shape(d_shape, tf.shape(t))
    return d_shape
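
A short sketch of the static-then-dynamic dispatch this helper relies on, using only public TensorFlow ops (shapes are illustrative):

```python
import tensorflow as tf

a = tf.zeros([3, 1, 5])
b = tf.zeros([1, 4, 5])

# Static path: both shapes are fully known, so a plain Python list is available.
static = tf.broadcast_static_shape(a.shape, b.shape).as_list()   # [3, 4, 5]

# Dynamic path: with a partially unknown shape the helper falls back to an
# int32 Tensor computed at graph execution time.
@tf.function(input_signature=[tf.TensorSpec([None, 1, 5])])
def dynamic_broadcast(x):
  return tf.broadcast_dynamic_shape(tf.shape(x), tf.shape(b))
```
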
Example 14
def _move_dims_to_flat_end(x, axis, x_ndims, right_end=True):
    """Move dims corresponding to `axis` in `x` to the end, then flatten.

  Args:
    x: `Tensor` with shape `[B0,B1,...,Bb]`.
    axis:  Python list of indices into dimensions of `x`.
    x_ndims:  Python integer holding number of dimensions in `x`.
    right_end:  Python bool.  Whether to move dims to the right end (else left).

  Returns:
    `Tensor` with the same values as `x`, but with the dims in `axis` moved to
      one end (the right end if `right_end` is `True`) and flattened into a
      single dimension.
  """

    if not axis:
        return x

    # Suppose x.shape = [a, b, c, d]
    # Suppose axis = [1, 3]

    # other_dims = [0, 2] in example above.
    other_dims = sorted(set(range(x_ndims)).difference(axis))
    # x_permed.shape = [a, c, b, d]
    perm = other_dims + list(axis) if right_end else list(axis) + other_dims
    x_permed = tf.transpose(a=x, perm=perm)

    if tensorshape_util.is_fully_defined(x.shape):
        x_shape = tensorshape_util.as_list(x.shape)
        # other_shape = [a, c], end_shape = [b * d]
        other_shape = [x_shape[i] for i in other_dims]
        end_shape = [np.prod([x_shape[i] for i in axis])]
        full_shape = (other_shape + end_shape if right_end else end_shape +
                      other_shape)
    else:
        other_shape = ps.gather(ps.shape(x), ps.cast(other_dims, tf.int64))
        full_shape = ps.concat(
            [other_shape, [-1]] if right_end else [[-1], other_shape], axis=0)
    return tf.reshape(x_permed, shape=full_shape)
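
A self-contained sketch of the same transform using public ops (not the helper itself): with `x.shape = [2, 3, 4, 5]` and `axis = [1, 3]`, the dims of size `3` and `5` are moved to the right end and flattened, giving shape `[2, 4, 15]`:

```python
import tensorflow as tf

x = tf.zeros([2, 3, 4, 5])
axis = [1, 3]
other_dims = [d for d in range(4) if d not in axis]      # [0, 2]

x_permed = tf.transpose(x, perm=other_dims + axis)        # shape [2, 4, 3, 5]
flat = tf.reshape(x_permed, tf.concat(
    [tf.shape(x_permed)[:len(other_dims)], [-1]], axis=0))
print(flat.shape)                                          # (2, 4, 15)
```
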
Example 15
 def _calculate_new_shape(self):
     # Try to get the old shape statically if available.
     original_shape = self._distribution.batch_shape
     if not tensorshape_util.is_fully_defined(original_shape):
         original_shape = self._distribution.batch_shape_tensor()
     # This is not a check for falseness, it's a check for exactly that shape.
     if original_shape == ():  # pylint: disable=g-explicit-bool-comparison
         # Force the size to be an integer, not a float, when the shape contains no
         # dtype information.
         original_size = 1
     else:
         original_size = ps.reduce_prod(original_shape)
     original_size = ps.cast(original_size, tf.int32)
     # Compute the new shape, filling in the `-1` dimension if present.
     new_shape = self._batch_shape_unexpanded
     implicit_dim_mask = ps.equal(new_shape, -1)
     size_implicit_dim = (original_size //
                          ps.maximum(1, -ps.reduce_prod(new_shape)))
     expanded_new_shape = ps.where(  # Assumes exactly one `-1`.
         implicit_dim_mask, size_implicit_dim, new_shape)
     # Return the original size on the side because one caller would otherwise
     # have to recompute it.
     return expanded_new_shape, original_size
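
The `-1`-filling logic above mirrors the behaviour of the public `tfd.BatchReshape` wrapper; a minimal sketch assuming that class (base distribution and shapes are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([6]), scale=tf.ones([6]))   # batch shape [6]
reshaped = tfd.BatchReshape(base, batch_shape=[2, -1])     # -1 resolves to 3
print(reshaped.batch_shape)                                 # [2, 3]
```
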
  def event_shape_tensor(self, name='event_shape_tensor'):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_and_control_scope(name):
      if all([tensorshape_util.is_fully_defined(s)
              for s in nest.flatten(self.event_shape)]):
        event_shape = nest.map_structure_up_to(
            self.dtype,
            tensorshape_util.as_list,
            self.event_shape, check_types=False)
      else:
        event_shape = self._event_shape_tensor()
      return nest.map_structure_up_to(
          self.dtype,
          lambda s: tf.identity(  # pylint: disable=g-long-lambda
              tf.convert_to_tensor(s, dtype=tf.int32), name='event_shape'),
          event_shape, check_types=False)
Example 17
 def _sample_shape(self, x):
   """Computes graph and static `sample_shape`."""
   x_ndims = (
       tf.rank(x) if tensorshape_util.rank(x.shape) is None else
       tensorshape_util.rank(x.shape))
   event_ndims = (
       tf.size(self.event_shape_tensor())
       if tensorshape_util.rank(self.event_shape) is None else
       tensorshape_util.rank(self.event_shape))
   batch_ndims = (
       tf.size(self._batch_shape_unexpanded)
       if tensorshape_util.rank(self.batch_shape) is None else
       tensorshape_util.rank(self.batch_shape))
   sample_ndims = x_ndims - batch_ndims - event_ndims
   if isinstance(sample_ndims, int):
     static_sample_shape = x.shape[:sample_ndims]
   else:
     static_sample_shape = tf.TensorShape(None)
   if tensorshape_util.is_fully_defined(static_sample_shape):
     sample_shape = np.int32(static_sample_shape)
   else:
     sample_shape = tf.shape(x)[:sample_ndims]
   return sample_shape, static_sample_shape
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
    """
    if isinstance(sample_shape, tf.TensorShape):
      if not tensorshape_util.is_fully_defined(sample_shape):
        raise ValueError('TensorShape sample_shape must be fully defined')
      sample_shape = tensorshape_util.as_list(sample_shape)

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tf.get_static_value(shape)
      if static_shape is None:
        raise ValueError(
            'sample_shape must be a fully-defined TensorShape or list/tuple')
      static_params[name] = tf.TensorShape(static_shape)

    return static_params
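
A short usage sketch of this class method on a concrete distribution (assuming `tfd.Normal`; the sample shape is illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

shapes = tfd.Normal.param_static_shapes(sample_shape=[100])
# ==> {'loc': TensorShape([100]), 'scale': TensorShape([100])}
```
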
Example 19
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be
      compatible with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if both the event shape portion of `input_tensorshape` and
      `event_shape_in` can be determined statically and they are not
      compatible. "Compatible" here means that they are identical on any dims
      that are not -1 in `event_shape_in`.
  """
  event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
  if tensorshape_util.rank(
      input_tensorshape) is None or event_shape_in_ndims is None:
    return tf.TensorShape(None), False  # Not is_validated.

  input_non_event_ndims = tensorshape_util.rank(
      input_tensorshape) - event_shape_in_ndims
  if input_non_event_ndims < 0:
    raise ValueError(
        'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
            tensorshape_util.rank(input_tensorshape), event_shape_in_ndims))

  input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
  input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

  # Check that `input_event_shape_` and `event_shape_in` are compatible in the
  # sense that they have equal entries in any position that isn't a `-1` in
  # `event_shape_in`. Note that our validations at construction time ensure
  # there is at most one such entry in `event_shape_in`.
  event_shape_in_ = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(input_event_tensorshape) and
      event_shape_in_ is not None)
  if is_validated:
    input_event_shape_ = np.int32(input_event_tensorshape)
    mask = event_shape_in_ >= 0
    explicit_input_event_shape_ = input_event_shape_[mask]
    explicit_event_shape_in_ = event_shape_in_[mask]
    if not all(explicit_input_event_shape_ == explicit_event_shape_in_):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in`. '
          '({} vs {}).'.format(input_event_shape_, event_shape_in_))

  event_tensorshape_out = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  if tensorshape_util.rank(event_tensorshape_out) is None:
    output_tensorshape = tf.TensorShape(None)
  else:
    output_tensorshape = tensorshape_util.concatenate(
        input_non_event_tensorshape, event_tensorshape_out)

  return output_tensorshape, is_validated
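
This helper backs event-shape replacement in the `tfb.Reshape` bijector; a hedged end-to-end sketch, assuming that public API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Replace a rightmost event shape of [6] with [3, 2].
reshape = tfb.Reshape(event_shape_out=[3, 2], event_shape_in=[6])
y = reshape.forward(tf.zeros([4, 6]))     # shape [4, 3, 2]
x = reshape.inverse(y)                    # shape [4, 6]
```
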
Example 20
def _replace_event_shape_in_shape_tensor(
    input_shape, event_shape_in, event_shape_out, validate_args):
  """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in the rightmost
      dims of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
  output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
      tensorshape_util.constant_value_as_shape(input_shape),
      event_shape_in,
      event_shape_out)

  # TODO(b/124240153): Remove map(tf.identity, deps) once tf.function
  # correctly supports control_dependencies.
  validation_dependencies = (
      map(tf.identity, (event_shape_in, event_shape_out))
      if validate_args else ())

  if (tensorshape_util.is_fully_defined(output_tensorshape) and
      (is_validated or not validate_args)):
    with tf.control_dependencies(validation_dependencies):
      output_shape = tf.convert_to_tensor(
          output_tensorshape, name='output_shape', dtype_hint=tf.int32)
    return output_shape, output_tensorshape

  with tf.control_dependencies(validation_dependencies):
    event_shape_in_ndims = (
        tf.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = tf.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

  additional_assertions = []
  if is_validated:
    pass
  elif validate_args:
    # Check that `input_event_shape` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    mask = event_shape_in >= 0
    explicit_input_event_shape = tf.boolean_mask(input_event_shape, mask=mask)
    explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
    additional_assertions.append(
        assert_util.assert_equal(
            explicit_input_event_shape,
            explicit_event_shape_in,
            message='Input `event_shape` does not match `event_shape_in`.'))
    # We don't explicitly additionally verify
    # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
    # already makes this assertion.

  with tf.control_dependencies(additional_assertions):
    output_shape = tf.concat([input_non_event_shape, event_shape_out], axis=0,
                             name='output_shape')

  return output_shape, output_tensorshape
Example 21
    def __init__(self,
                 image_shape: tuple,
                 conditional_shape: tuple = None,
                 num_resnet: int = 5,
                 num_hierarchies: int = 3,
                 num_filters: int = 160,
                 num_logistic_mix: int = 10,
                 receptive_field_dims: tuple = (3, 3),
                 dropout_p: float = 0.5,
                 resnet_activation: str = 'concat_elu',
                 l2_weight: float = 0.,
                 use_weight_norm: bool = True,
                 use_data_init: bool = True,
                 high: int = 255,
                 low: int = 0,
                 dtype=tf.float32,
                 name: str = 'PixelCNN') -> None:
        """
        Construct Pixel CNN++ distribution.

        Parameters
        ----------
        image_shape
            3D `TensorShape` or tuple for the `[height, width, channels]` dimensions of the image.
        conditional_shape
            `TensorShape` or tuple for the shape of the conditional input, or `None` if there is no conditional input.
        num_resnet
            The number of layers (shown in Figure 2 of [2]) within each highest-level block of Figure 2 of [1].
        num_hierarchies
            The number of highest-level blocks (separated by expansions/contractions of dimensions in Figure 2 of [1].)
        num_filters
            The number of convolutional filters.
        num_logistic_mix
            Number of components in the logistic mixture distribution.
        receptive_field_dims
            Height and width in pixels of the receptive field of the convolutional layers above and to the left
            of a given pixel. The width (second element of the tuple) should be odd. Figure 1 (middle) of [2]
            shows a receptive field of (3, 5) (the row containing the current pixel is included in the height).
            The default of (3, 3) was used to produce the results in [1].
        dropout_p
            The dropout probability. Should be between 0 and 1.
        resnet_activation
            The type of activation to use in the resnet blocks. May be 'concat_elu', 'elu', or 'relu'.
        l2_weight
            The L2 regularization weight applied to the network weights.
        use_weight_norm
            If `True` then use weight normalization (works only in Eager mode).
        use_data_init
            If `True` then use data-dependent initialization (has no effect if `use_weight_norm` is `False`).
        high
            The maximum value of the input data (255 for an 8-bit image).
        low
            The minimum value of the input data.
        dtype
            Data type of the `Distribution`.
        name
            The name of the `Distribution`.
        """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            super(PixelCNN, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=False,
                allow_nan_stats=True,
                parameters=parameters,
                name=name)

            if not tensorshape_util.is_fully_defined(image_shape):
                raise ValueError('`image_shape` must be fully defined.')

            if conditional_shape is not None and not tensorshape_util.is_fully_defined(
                    conditional_shape):
                raise ValueError('`conditional_shape` must be fully defined.')

            if tensorshape_util.rank(image_shape) != 3:
                raise ValueError(
                    '`image_shape` must have length 3, representing [height, width, channels] dimensions.'
                )

            self._high = tf.cast(high, self.dtype)
            self._low = tf.cast(low, self.dtype)
            self._num_logistic_mix = num_logistic_mix
            self.network = _PixelCNNNetwork(
                dropout_p=dropout_p,
                num_resnet=num_resnet,
                num_hierarchies=num_hierarchies,
                num_filters=num_filters,
                num_logistic_mix=num_logistic_mix,
                receptive_field_dims=receptive_field_dims,
                resnet_activation=resnet_activation,
                l2_weight=l2_weight,
                use_weight_norm=use_weight_norm,
                use_data_init=use_data_init,
                dtype=dtype)

            image_input_shape = tensorshape_util.concatenate([None],
                                                             image_shape)
            if conditional_shape is None:
                input_shape = image_input_shape
            else:
                conditional_input_shape = tensorshape_util.concatenate(
                    [None], conditional_shape)
                input_shape = [image_input_shape, conditional_input_shape]

            self.image_shape = image_shape
            self.conditional_shape = conditional_shape
            self.network.build(input_shape)
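
A minimal construction-and-scoring sketch. This assumes the `tfd.PixelCNN` API that this constructor mirrors (the snippet above appears to be an adapted copy, e.g. it adds an `l2_weight` argument); hyperparameters are illustrative and kept small:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.PixelCNN(
    image_shape=(28, 28, 1),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=5,
    dropout_p=0.3)

images = tf.zeros([7, 28, 28, 1])
log_prob = dist.log_prob(images)   # shape [7]
# dist.sample(1) is also available, but samples pixel-by-pixel and is slow.
```
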
Example 22
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = ps.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated = x_rotated - tf.reduce_mean(
                x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = ps.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = ps.cast(x_len, np.float64)
        target_length = ps.pow(np.float64(2.),
                               ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = ps.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension in the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - ps.range(0., max_lags + 1.)
        denominator = ps.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
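
A short usage sketch via the public entry point `tfp.stats.auto_correlation` (the data here is synthetic):

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant(np.random.randn(2, 1000), dtype=tf.float32)   # 2 series
rxx = tfp.stats.auto_correlation(x, axis=-1, max_lags=10)
print(rxx.shape)      # (2, 11)
# With normalize=True (the default), rxx[..., 0] == 1.
```
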
Example 23
def fill_triangular(x, upper=False, name=None):
    """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to create an upper triangular matrix by concatenating `x`
  and a tail of itself, then reshaping.

  Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
  from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
  contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
  (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
  the first (`n = 5`) elements removed and reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
  #            12, 11, 10, 9, 8, 7, 6])

  # (We add one to the arange result to disambiguate the zeros below the
  # diagonal of our upper-triangular matrix from the first entry in `x`.)

  # Now, we lay this out as a matrix:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  From this example we see that the resulting matrix is upper-triangular, and
  contains all the entries of `x`, as desired. The rest is details:
  - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
    `n / 2` rows and half of an additional row), but the whole scheme still
    works.
  - If we want a lower triangular matrix instead of an upper triangular,
    we remove the first `n` elements from `x` rather than from the reversed
    `x`.

  For additional comparisons, a pure numpy version of this function can be found
  in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

    with tf.name_scope(name or 'fill_triangular'):
        x = tf.convert_to_tensor(x, name='x')
        m = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
        if m is not None:
            # Formula derived by solving for n: m = n(n+1)/2.
            m = np.int32(m)
            n = np.sqrt(0.25 + 2. * m) - 0.5
            if n != np.floor(n):
                raise ValueError(
                    'Input right-most shape ({}) does not '
                    'correspond to a triangular matrix.'.format(m))
            n = np.int32(n)
            static_final_shape = tensorshape_util.concatenate(
                x.shape[:-1], [n, n])
        else:
            m = tf.shape(x)[-1]
            # For derivation, see above. Casting automatically lops off the 0.5, so we
            # omit it.  We don't validate n is an integer because this has
            # graph-execution cost; an error will be thrown from the reshape, below.
            n = tf.cast(tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)),
                        dtype=tf.int32)
            static_final_shape = tensorshape_util.concatenate(
                tensorshape_util.with_rank_at_least(x.shape, 1)[:-1],
                [None, None])

        # Try it out in numpy:
        #  n = 3
        #  x = np.arange(n * (n + 1) / 2)
        #  m = x.shape[0]
        #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
        #  x_tail = x[(m - (n**2 - m)):]
        #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
        #  # ==> array([[3, 4, 5],
        #               [5, 4, 3],
        #               [2, 1, 0]])
        #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
        #  # ==> array([[0, 1, 2],
        #               [3, 4, 5],
        #               [5, 4, 3]])
        #
        # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
        # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
        # Furthermore observe that:
        #   m - (n**2 - m)
        #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
        #   = 2 (n**2 / 2 + n / 2) - n**2
        #   = n**2 + n - n**2
        #   = n
        ndims = prefer_static.rank(x)
        if upper:
            x_list = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
        else:
            x_list = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
        new_shape = (tensorshape_util.as_list(static_final_shape)
                     if tensorshape_util.is_fully_defined(static_final_shape)
                     else tf.concat([tf.shape(x)[:-1], [n, n]], axis=0))
        x = tf.reshape(tf.concat(x_list, axis=-1), new_shape)
        x = tf.linalg.band_part(x,
                                num_lower=(0 if upper else -1),
                                num_upper=(-1 if upper else 0))
        tensorshape_util.set_shape(x, static_final_shape)
        return x
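
The same function is exposed publicly as `tfp.math.fill_triangular`; a quick round-trip sketch using `tfp.math.fill_triangular_inverse`, which is also part of the public API:

```python
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([1., 2., 3., 4., 5., 6.])
tril = tfp.math.fill_triangular(x)               # [[4., 0., 0.],
                                                 #  [6., 5., 0.],
                                                 #  [3., 2., 1.]]
x_back = tfp.math.fill_triangular_inverse(tril)  # recovers [1., 2., ..., 6.]
```
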
Example 24
    def _sample_n(self, n, seed=None):
        stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                input_tensor=self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(input_tensor=tf.shape(
                input=self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Example 25
def minimize(loss_fn,
             num_steps,
             optimizer,
             trainable_variables=None,
             trace_fn=_trace_loss,
             name='minimize'):
    """Minimize a loss function using a provided optimizer.

  Args:
    loss_fn: Python callable with signature `loss = loss_fn()`, where `loss`
      is a `Tensor` loss to be minimized.
    num_steps: Python `int` number of steps to run the optimizer.
    optimizer: Optimizer instance to use. This may be a TF1-style
      `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python
      object that implements `optimizer.apply_gradients(grads_and_vars)`.
    trainable_variables: list of `tf.Variable` instances to optimize with
      respect to. If `None`, defaults to the set of all variables accessed
      during the execution of `loss_fn()`.
      Default value: `None`.
    trace_fn: Python callable with signature `state = trace_fn(
      loss, grads, variables)`, where `state` may be a `Tensor` or nested
      structure of `Tensor`s. The state values are accumulated (by `tf.scan`)
      and returned. The default `trace_fn` simply returns the loss, but in
      general can depend on the gradients and variables (if
      `trainable_variables` is not `None` then `variables==trainable_variables`;
      otherwise it is the list of all variables accessed during execution of
      `loss_fn()`), as well as any other quantities captured in the closure of
      `trace_fn`, for example, statistics of a variational distribution.
      Default value: `lambda loss, grads, variables: loss`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'minimize'.

  Returns:
    trace: `Tensor` or nested structure of `Tensor`s, according to the
      return type of `trace_fn`. Each `Tensor` has an added leading dimension
      of size `num_steps`, packing the trajectory of the result over the course
      of the optimization.

  ### Examples

  To minimize the scalar function `(x - 5)**2`:

  ```python
  x = tf.Variable(0.)
  loss_fn = lambda: (x - 5.)**2
  losses = tfp.math.minimize(loss_fn,
                             num_steps=100,
                             optimizer=tf.optimizers.Adam(learning_rate=0.1))

  # In TF2/eager mode, the optimization runs immediately.
  print("optimized value is {} with loss {}".format(x, losses[-1]))
  ```

  In graph mode (e.g., inside of `tf.function` wrapping), retrieving any Tensor
  that depends on the minimization op will trigger the optimization:

  ```python
  with tf.control_dependencies([losses]):
    optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.
  ```

  In some cases, we may want to track additional context inside the
  optimization. We can do this by defining a custom `trace_fn`. Note that
  the `trace_fn` is passed the loss and gradients, but it may also report the
  values of trainable variables or other derived quantities by capturing them in
  its closure. For example, we can capture `x` and track its value over the
  optimization:

  ```python
  # `x` is the tf.Variable instance defined above.
  trace_fn = lambda loss, grads, variables: {'loss': loss, 'x': x}
  trace = tfp.vi.minimize(loss_fn, num_steps=100,
                          optimizer=tf.optimizers.Adam(0.1),
                          trace_fn=trace_fn)
  print(trace['loss'].shape,   # => [100]
        trace['x'].shape)      # => [100]
  ```
  """
    @tf.function(autograph=False)
    def train_loop_body(old_result, step):  # pylint: disable=unused-argument
        """Run a single optimization step."""
        with tf.GradientTape(
                watch_accessed_variables=trainable_variables is None) as tape:
            for v in trainable_variables or []:
                tape.watch(v)
            loss = loss_fn()
        watched_variables = tape.watched_variables()
        grads = tape.gradient(loss, watched_variables)
        train_op = optimizer.apply_gradients(zip(grads, watched_variables))
        with tf.control_dependencies([train_op]):
            state = trace_fn(tf.identity(loss),
                             [tf.identity(g) for g in grads],
                             [tf.identity(v) for v in watched_variables])
        return state

    with tf.name_scope(name) as name:
        # Compute the shape of the trace without executing the graph, if possible.
        concrete_loop_body = train_loop_body.get_concrete_function(
            tf.TensorSpec([]), tf.TensorSpec([]))  # Inputs ignored.
        if all([
                tensorshape_util.is_fully_defined(shape)
                for shape in tf.nest.flatten(concrete_loop_body.output_shapes)
        ]):
            state_initializer = tf.nest.map_structure(
                lambda shape, dtype: tf.zeros(shape, dtype=dtype),
                concrete_loop_body.output_shapes,
                concrete_loop_body.output_dtypes)
            initial_trace_step = None
        else:
            state_initializer = concrete_loop_body(
                tf.convert_to_tensor(0.),
                tf.convert_to_tensor(0.))  # Inputs ignored.
            num_steps = num_steps - 1
            initial_trace_step = state_initializer

        # TODO(b/136103064): Rewrite as explicit `while_loop` to support custom
        # convergence criteria and Tensor-valued `num_steps`, and avoid
        # re-tracing the train loop body.
        trace = tf.scan(train_loop_body,
                        elems=np.arange(num_steps),
                        initializer=state_initializer)
        if initial_trace_step is not None:
            trace = tf.nest.map_structure(
                lambda a, b: tf.concat([a[tf.newaxis, ...], b], axis=0),
                initial_trace_step, trace)
        return trace
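
# A standalone sketch of how `tf.scan` packs per-step trace values with a
# leading `num_steps` dimension, mirroring the `tf.scan` call at the end of
# `minimize` above (the toy loop body and shapes are our own illustration):
import numpy as np
import tensorflow as tf

loop_body = lambda old_state, step: {'loss': tf.cast(step, tf.float32),
                                     'x': tf.zeros([2])}
initializer = {'loss': tf.zeros([]), 'x': tf.zeros([2])}
trace = tf.scan(loop_body, elems=np.arange(5), initializer=initializer)
# trace['loss'].shape == [5]; trace['x'].shape == [5, 2]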
Example 26
def pad_batch_dimension_for_multiple_chains(observed_time_series, model,
                                            chain_batch_shape):
    """"Expand the observed time series with extra batch dimension(s)."""
    # Running with multiple chains introduces an extra batch dimension. In
    # general we also need to pad the observed time series with a matching batch
    # dimension.
    #
    # For example, suppose our model has batch shape [3, 4] and
    # the observed time series has shape `concat([[5], [3, 4], [100]])`,
    # corresponding to `sample_shape`, `batch_shape`, and `num_timesteps`
    # respectively. The model will produce distributions with batch shape
    # `concat([chain_batch_shape, [3, 4]])`, so we pad `observed_time_series` to
    # have matching shape `[5, 1, 3, 4, 100]`, where the added `1` dimension
    # between the sample and batch shapes will broadcast to `chain_batch_shape`.

    [  # Extract mask and guarantee `event_ndims=2`.
        observed_time_series, is_missing
    ] = canonicalize_observed_time_series_with_mask(observed_time_series)

    event_ndims = 2  # event_shape = [num_timesteps, observation_size=1]

    model_batch_ndims = (tensorshape_util.rank(model.batch_shape) if
                         tensorshape_util.rank(model.batch_shape) is not None
                         else tf.shape(model.batch_shape_tensor())[0])

    # Compute ndims from chain_batch_shape.
    chain_batch_shape = tf.convert_to_tensor(value=chain_batch_shape,
                                             name='chain_batch_shape',
                                             dtype=tf.int32)
    if not tensorshape_util.is_fully_defined(chain_batch_shape.shape):
        raise ValueError(
            'Batch shape must have static rank. (given: {})'.format(
                chain_batch_shape))
    if tensorshape_util.rank(chain_batch_shape.shape) == 0:
        # expand int `k` to `[k]`.
        chain_batch_shape = chain_batch_shape[tf.newaxis]
    chain_batch_ndims = tf.compat.dimension_value(chain_batch_shape.shape[0])

    def do_padding(observed_time_series_tensor):
        current_sample_shape = ps.shape(
            observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
        current_batch_and_event_shape = ps.shape(
            observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
        return tf.reshape(tensor=observed_time_series_tensor,
                          shape=ps.concat([
                              current_sample_shape,
                              ps.ones([chain_batch_ndims], dtype=tf.int32),
                              current_batch_and_event_shape
                          ],
                                          axis=0))

    # Padding is only needed if the observed time series has sample shape.
    observed_time_series = ps.cond(
        ps.rank(observed_time_series) > model_batch_ndims + event_ndims,
        lambda: do_padding(observed_time_series), lambda: observed_time_series)
    if is_missing is not None:
        is_missing = ps.cond(
            ps.rank(is_missing) > model_batch_ndims + event_ndims,
            lambda: do_padding(is_missing), lambda: is_missing)
    return missing_values_util.MaskedTimeSeries(observed_time_series,
                                                is_missing=is_missing)
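
# A standalone, shape-only illustration of the padding performed above. The
# shapes mirror the worked example in the comments ([5] sample shape, [3, 4]
# model batch shape, [100] timesteps); the zero values are just for demonstration.
import tensorflow as tf

observed = tf.zeros([5, 3, 4, 100])
padded = tf.reshape(observed, [5, 1, 3, 4, 100])
# The inserted `1` will broadcast against `chain_batch_shape` once the model
# produces distributions with batch shape `concat([chain_batch_shape, [3, 4]])`.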
Example 27
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            stream = SeedStream(seed, salt='Mixture')

            for c in range(self.num_components):
                samples.append(self.components[c].sample(n, seed=stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seed)

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   Suppose n=3 samples are drawn over a batch of size 2, so the
        #   (flattened) batch indices are [0, 1, 0, 1, 0, 1], and the
        #   partitions (cat samples) are:
        #     [1 1 0 0 1 1]
        # After partitioning, the batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]        ->  [0 1]
        #     [batch_indices[x] for x in 0, 1, 4, 5]  ->  [0 1 0 1]
        # So we draw n_class=2 samples from component 0 and n_class=4 samples
        # from component 1. Component 0's samples 0, 1 are routed to batch
        # entries 0, 1; component 1's samples 0, 1, 2, 3 are routed to batch
        # entries 0, 1, 0, 1.
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        stream = SeedStream(seed, salt='Mixture')

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            samples_class_c = self.components[c].sample(n_class, seed=stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
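
# A standalone sketch of the partition/stitch round trip used in the dynamic
# code path above; the partition values match the worked example in the
# comments, and the data are illustrative only.
import tensorflow as tf

cat_samples = tf.constant([1, 1, 0, 0, 1, 1])
raw_indices = tf.range(6)
parts = tf.dynamic_partition(raw_indices, cat_samples, num_partitions=2)
# parts == [[2, 3], [0, 1, 4, 5]]
per_component = [10 * p for p in parts]  # Stand-in for per-component samples.
stitched = tf.dynamic_stitch(parts, per_component)
# stitched == [0, 10, 20, 30, 40, 50]: every value lands back at its original
# position in the flat sample order.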
Example 28
def make_momentum_distribution(state_parts,
                               batch_shape,
                               running_variance_parts=None,
                               shard_axis_names=None):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    state_parts: List of `Tensor`.
    batch_shape: Batch shape.
    running_variance_parts: Optional, list of `Tensor`
       outputs of `tfp.experimental.stats.RunningVariance.variance()`. Defaults
       to ones with the same shape as state_parts.
    shard_axis_names: A structure of string names indicating how members of the
      state are sharded.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of a sample has rank equal to the rank of `batch_shape`.
  """
    if running_variance_parts is None:
        running_variance_parts = tf.nest.map_structure(tf.ones_like,
                                                       state_parts)
    distributions = []
    batch_ndims = ps.rank_from_shape(batch_shape)
    use_sharded_jd = True
    if shard_axis_names is None:
        use_sharded_jd = False
        shard_axis_names = [None] * len(state_parts)
    for variance_part, state_part, shard_axes in zip(running_variance_parts,
                                                     state_parts,
                                                     shard_axis_names):
        event_shape = state_part.shape[batch_ndims:]
        if not tensorshape_util.is_fully_defined(event_shape):
            event_shape = ps.shape(state_part,
                                   name='state_part_shp')[batch_ndims:]
        variance_tiled = tf.broadcast_to(
            variance_part, ps.concat([batch_shape, event_shape], axis=0))
        nevt = ps.cast(ps.reduce_prod(event_shape), tf.int32)
        variance_flattened = tf.reshape(
            variance_tiled, ps.concat([batch_shape, [nevt]], axis=0))

        distribution = _CompositeTransformedDistribution(
            bijector=reshape.Reshape(event_shape_out=event_shape,
                                     name='reshape_mvnpfl'),
            distribution=(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=tf.linalg.LinearOperatorDiag(
                        tf.math.sqrt(variance_flattened)),
                    precision=tf.linalg.LinearOperatorDiag(variance_flattened),
                    name='momentum')))
        if shard_axes:
            distribution = sharded.Sharded(distribution,
                                           shard_axis_name=shard_axes)
        distributions.append(distribution)
    if use_sharded_jd:
        jd = _CompositeShardedJointDistributionSequential(distributions)
    else:
        jd = _CompositeJointDistributionSequential(distributions)
    return maybe_make_list_and_batch_broadcast(jd, batch_shape)
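
# A hedged usage sketch (assumes `make_momentum_distribution` from this module
# is in scope; the state shapes and seed below are our own illustration):
import tensorflow as tf

state_parts = [tf.zeros([8, 3]), tf.zeros([8, 2, 2])]  # batch_shape == [8]
momentum_dist = make_momentum_distribution(state_parts, batch_shape=[8])
momentum = momentum_dist.sample(seed=[0, 0])
# `momentum` mirrors `state_parts` (shapes [8, 3] and [8, 2, 2]), and the
# log-prob of the sample has shape [8], matching the batch shape.
lp = momentum_dist.log_prob(momentum)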
Example 29
    def testBijector(self, bijector_name, data):
        tfp_hps.guitar_skip_if_matches('Tanh', bijector_name, 'b/144163991')

        bijector, event_dim = self._draw_bijector(bijector_name, data)

        # Forward mapping: Check differentiation through forward mapping with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        xs = self._draw_domain_tensor(bijector, data, event_dim)
        wrt_vars = [xs] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ys = bijector.forward(xs + 0)
        grads = tape.gradient(ys, wrt_vars)
        assert_no_none_grad(bijector, 'forward', wrt_vars, grads)

        # For scalar bijectors, verify correctness of the _is_increasing method.
        # TODO(b/148459057): Except, don't verify Softfloor on Guitar because
        # of a numerical problem.
        def exception(bijector):
            if not tfp_hps.running_under_guitar():
                return False
            if isinstance(bijector, tfb.Softfloor):
                return True
            if is_invert(bijector):
                return exception(bijector.bijector)
            return False

        if (bijector.forward_min_event_ndims == 0
                and bijector.inverse_min_event_ndims == 0
                and not exception(bijector)):
            dydx = grads[0]
            hp.note('dydx: {}'.format(dydx))
            isfinite = tf.math.is_finite(dydx)
            incr_or_slope_eq0 = bijector._internal_is_increasing() | tf.equal(
                dydx, 0)  # pylint: disable=protected-access
            self.assertAllEqual(
                isfinite & incr_or_slope_eq0,
                isfinite & (dydx >= 0) | tf.zeros_like(incr_or_slope_eq0))

        # FLDJ: Check differentiation through forward log det jacobian with
        # respect to the input and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.forward_min_event_ndims,
                         max_value=xs.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=True)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `forward_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.forward_log_det_jacobian(
                    xs + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'forward_log_det_jacobian', wrt_vars,
                            grads)

        # Inverse mapping: Check differentiation through inverse mapping with
        # respect to the codomain "input" and parameter variables.  Also check that
        # any variables are not referenced overmuch.
        ys = self._draw_codomain_tensor(bijector, data, event_dim)
        wrt_vars = [ys] + [
            v for v in bijector.trainable_variables if v.dtype.is_floating
        ]
        with tf.GradientTape() as tape:
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse` of {}'.format(bijector)):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                xs = bijector.inverse(ys + 0)
        grads = tape.gradient(xs, wrt_vars)
        assert_no_none_grad(bijector, 'inverse', wrt_vars, grads)

        # ILDJ: Check differentiation through inverse log det jacobian with respect
        # to the codomain "input" and parameter variables.  Also check that any
        # variables are not referenced overmuch.
        event_ndims = data.draw(
            hps.integers(min_value=bijector.inverse_min_event_ndims,
                         max_value=ys.shape.ndims))
        with tf.GradientTape() as tape:
            max_permitted = _ldj_tensor_conversions_allowed(bijector,
                                                            is_forward=False)
            with tfp_hps.assert_no_excessive_var_usage(
                    'method `inverse_log_det_jacobian` of {}'.format(bijector),
                    max_permissible=max_permitted):
                tape.watch(wrt_vars)
                # TODO(b/73073515): Fix graph mode gradients with bijector caching.
                ldj = bijector.inverse_log_det_jacobian(
                    ys + 0, event_ndims=event_ndims)
        grads = tape.gradient(ldj, wrt_vars)
        assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars,
                            grads)

        # Verify that `_is_permutation` implies constant zero Jacobian.
        if bijector._is_permutation:
            self.assertTrue(bijector._is_constant_jacobian)
            self.assertAllEqual(ldj, 0.)

        # Verify correctness of batch shape.
        xs_batch_shapes = tf.nest.map_structure(
            lambda x, nd: ps.shape(x)[:ps.rank(x) - nd], xs,
            bijector.inverse_event_ndims(event_ndims))
        empirical_batch_shape = functools.reduce(
            ps.broadcast_shape,
            nest.flatten_up_to(bijector.forward_min_event_ndims,
                               xs_batch_shapes))
        batch_shape = bijector.experimental_batch_shape(
            y_event_ndims=event_ndims)
        if tensorshape_util.is_fully_defined(batch_shape):
            self.assertAllEqual(empirical_batch_shape, batch_shape)
        self.assertAllEqual(
            empirical_batch_shape,
            bijector.experimental_batch_shape_tensor(
                y_event_ndims=event_ndims))

        # Check that the outputs of forward_dtype and inverse_dtype match the dtypes
        # of the outputs of forward and inverse.
        self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
        self.assertAllEqualNested(xs.dtype, bijector.inverse_dtype(ys.dtype))
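
# A standalone, concrete version of the FLDJ gradient check above, using
# tfb.Softplus (our choice) in place of a Hypothesis-drawn bijector:
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Softplus(hinge_softness=tf.Variable(1.))
xs = tf.constant([0.5, 1.5])
with tf.GradientTape() as tape:
    tape.watch(xs)
    ldj = bij.forward_log_det_jacobian(xs, event_ndims=1)
grads = tape.gradient(ldj, [xs] + list(bij.trainable_variables))
assert all(g is not None for g in grads)  # No `None` gradients expected.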
Example 30
def _setup_mcmc(model, n_chains, *, init_position=None, seed=None, **pins):
    """Construct bijector and transforms needed for windowed MCMC.

  This pins the initial model, constructs a bijector that unconstrains and
  flattens each dimension and adds a leading batch shape of `n_chains`,
  initializes a point in the unconstrained space, and constructs a transformed
  log probability using the bijector.

  Note that we must manually construct this target log probability instead of
  using a TransformedTransitionKernel, because that kernel assumes its input
  and output shapes are the same.

  Args:
    model: `tfd.JointDistribution`
      The model to sample from.
    n_chains: list of ints
      Number of chains (independent examples) to run.
    init_position: Optional
      Structure of tensors at which to initialize sampling. Should have the
      same shape and structure as
      `model.experimental_pin(**pins).sample_unpinned(n_chains)`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    **pins:
      Values passed to `model.experimental_pin`.
  Returns:
    target_log_prob_fn: Callable on the transformed space.
    initial_transformed_position: `tf.Tensor`, sampled from a uniform (-2, 2).
    bijector: `tfb.Bijector` instance, which unconstrains and flattens.
    step_broadcast_fn: Callable to broadcast step size over latent structure.
    batch_shape: Batch shape of the model.
    shard_axis_names: Shard axis names for the model.
  """
    pinned_model = model.experimental_pin(**pins) if pins else model
    bijector, step_bijector = _get_flat_unconstraining_bijector(pinned_model)

    if init_position is None:
        raw_init_dist = initialization.init_near_unconstrained_zero(
            pinned_model)
        init_position = initialization.retry_init(
            raw_init_dist.sample,
            target_fn=pinned_model.unnormalized_log_prob,
            sample_shape=n_chains,
            seed=seed)

    initial_transformed_position = tf.nest.map_structure(
        tf.identity, bijector.forward(init_position))

    batch_shape = pinned_model.batch_shape
    if tf.nest.is_nested(batch_shape):
        batch_shape = functools.reduce(tf.broadcast_static_shape,
                                       tf.nest.flatten(batch_shape))

    if not tensorshape_util.is_fully_defined(batch_shape):
        batch_shape = pinned_model.batch_shape_tensor()
        if tf.nest.is_nested(batch_shape):
            batch_shape = functools.reduce(tf.broadcast_dynamic_shape,
                                           tf.nest.flatten(batch_shape))

    # This tf.function is not redundant with the ones on _fast_window
    # and _slow_window because the various kernels (like HMC) may invoke
    # `target_log_prob_fn` multiple times within one window.
    @tf.function(autograph=False)
    def target_log_prob_fn(*args):
        lp = pinned_model.unnormalized_log_prob(bijector.inverse(args))
        ldj = bijector.inverse_log_det_jacobian(
            args, event_ndims=[1 for _ in initial_transformed_position])
        return lp + ldj

    def step_broadcast(step_size):
        # Only apply the bijector to nested step sizes or non-scalar batches.
        if tf.nest.is_nested(step_size):
            return step_bijector(
                nest_util.broadcast_structure(
                    pinned_model.event_shape_tensor(), step_size))
        else:
            return step_size

    shard_axis_names = pinned_model.experimental_shard_axis_names
    if any(tf.nest.flatten(shard_axis_names)):
        shard_axis_names = nest.flatten_up_to(
            initial_transformed_position,
            list(pinned_model._model_flatten(shard_axis_names)))  # pylint: disable=protected-access

    else:
        # No active shard axis names
        shard_axis_names = None

    return (target_log_prob_fn, initial_transformed_position, bijector,
            step_broadcast,
            ps.convert_to_shape_tensor(batch_shape,
                                       name='batch_shape'), shard_axis_names)
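
# A hedged, illustrative sketch of calling `_setup_mcmc` (the function is
# internal, so the toy model, chain count, and seed format below are our own
# assumptions rather than documented usage):
import tensorflow_probability as tfp
tfd = tfp.distributions

model = tfd.JointDistributionNamed(dict(
    x=tfd.Normal(0., 1.),
    y=lambda x: tfd.Normal(x, 1.)))
(target_log_prob_fn, initial_transformed_position, bijector,
 step_broadcast_fn, batch_shape, shard_axis_names) = _setup_mcmc(
     model, n_chains=[16], seed=[0, 0], y=2.)  # Pins `y` to 2.
# `target_log_prob_fn(*initial_transformed_position)` evaluates the pinned,
# unconstrained-space log density, one value per chain.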