def _forward(self, x, **kwargs):
        static_event_size = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(
                x.shape, self._event_ndims)[-self._event_ndims:])

        if self._unroll_loop:
            if not static_event_size:
                raise ValueError(
                    "The final {} dimensions of `x` must be known at graph "
                    "construction time if `unroll_loop=True`. `x.shape: {!r}`".
                    format(self._event_ndims, x.shape))
            y = tf.zeros_like(x, name="y0")

            for _ in range(static_event_size):
                shift, log_scale = self._shift_and_log_scale_fn(y, **kwargs)
                # next_y = scale * x + shift
                next_y = x
                if log_scale is not None:
                    next_y *= tf.exp(log_scale)
                if shift is not None:
                    next_y += shift
                y = next_y
            return y

        event_size = tf.reduce_prod(input_tensor=tf.shape(
            input=x)[-self._event_ndims:])
        y0 = tf.zeros_like(x, name="y0")
        # Call the template once to ensure variable creation.
        _ = self._shift_and_log_scale_fn(y0, **kwargs)

        def _loop_body(index, y0):
            """While-loop body for autoregression calculation."""
            # Set caching device to avoid re-getting the tf.Variable for every while
            # loop iteration.
            with tf.compat.v1.variable_scope(
                    tf.compat.v1.get_variable_scope()) as vs:
                if vs.caching_device is None and not tf.executing_eagerly():
                    vs.set_caching_device(lambda op: op.device)
                shift, log_scale = self._shift_and_log_scale_fn(y0, **kwargs)
            y = x
            if log_scale is not None:
                y *= tf.exp(log_scale)
            if shift is not None:
                y += shift
            return index + 1, y

        # If the event size is available at graph construction time, we can inform
        # the graph compiler of the maximum number of steps. If not,
        # static_event_size will be None, and the maximum_iterations argument will
        # have no effect.
        _, y = tf.while_loop(cond=lambda index, _: index < event_size,
                             body=_loop_body,
                             loop_vars=(0, y0),
                             maximum_iterations=static_event_size)
        return y
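The loop above is the standard masked-autoregressive forward pass: starting from zeros, each pass through the shift-and-log-scale network resolves one more coordinate of `y`, so `event_size` iterations suffice. A minimal NumPy sketch of the same fixed-point idea, assuming a hypothetical `shift_and_log_scale` callable:

```python
import numpy as np

def toy_maf_forward(x, shift_and_log_scale, event_size):
    # Hypothetical autoregressive network: maps y -> (shift, log_scale),
    # where output i depends only on y[..., :i].
    y = np.zeros_like(x)
    for _ in range(event_size):
        shift, log_scale = shift_and_log_scale(y)
        y = x * np.exp(log_scale) + shift
    return y
```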
Example 2
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher'):
        """Creates a new `VonMisesFisher` instance.

    Args:
      mean_direction: Floating-point `Tensor` with shape [B1, ... Bn, D].
        A unit vector indicating the mode of the distribution, or the
        unit-normalized direction of the mean. NOTE: `D` is currently
        restricted to <= 5.
      concentration: Floating-point `Tensor` having batch shape [B1, ... Bn]
        broadcastable with `mean_direction`. The level of concentration of
        samples around the `mean_direction`. `concentration=0` indicates a
        uniform distribution over the unit hypersphere, and `concentration=+inf`
        indicates a `Deterministic` distribution (delta function) at
        `mean_direction`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            self._mean_direction = tensor_util.convert_nonref_to_tensor(
                mean_direction, name='mean_direction', dtype=dtype)
            self._concentration = tensor_util.convert_nonref_to_tensor(
                concentration, name='concentration', dtype=dtype)

            static_event_dim = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(self._mean_direction.shape,
                                                    1)[-1])

            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED
                                       if static_event_dim == 3 else
                                       reparameterization.NOT_REPARAMETERIZED)
            super(VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
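A hedged usage sketch of the public distribution this constructor belongs to (`tfp.distributions.VonMisesFisher`); sampling and `log_prob` are assumed to follow the standard TFP distribution API:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# A vMF distribution on the unit circle (D = 2), concentrated around [0., 1.].
vmf = tfd.VonMisesFisher(mean_direction=[0., 1.], concentration=2.)
samples = vmf.sample(5)          # shape [5, 2]; each row is a unit vector
log_probs = vmf.log_prob(samples)
```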
Example 3
    def _forward_event_shape(self, input_shape):
        input_shape = tensorshape_util.with_rank_at_least(input_shape, 1)
        static_block_sizes = tf.get_static_value(self.block_sizes)
        if static_block_sizes is None:
            return tensorshape_util.concatenate(input_shape[:-1], [None])

        output_size = sum(
            b.forward_event_shape([bs])[0]
            for b, bs in zip(self.bijectors, static_block_sizes))

        return tensorshape_util.concatenate(input_shape[:-1], [output_size])
Example 4
  def test_with_rank_list_tuple(self):
    with self.assertRaises(ValueError):
      tensorshape_util.with_rank([2], 2)

    with self.assertRaises(ValueError):
      tensorshape_util.with_rank((2,), 2)

    self.assertAllEqual(
        (2, 1),
        tensorshape_util.with_rank((2, 1), 2))
    self.assertAllEqual(
        [2, 1],
        tensorshape_util.with_rank([2, 1], 2))

    self.assertAllEqual(
        (2, 3, 4),
        tensorshape_util.with_rank_at_least((2, 3, 4), 2))
    self.assertAllEqual(
        [2, 3, 4],
        tensorshape_util.with_rank_at_least([2, 3, 4], 2))

  def test_with_rank_ndarray(self):
    x = np.array([2], dtype=np.int32)
    with self.assertRaises(ValueError):
      tensorshape_util.with_rank(x, 2)

    x = np.array([2, 3, 4], dtype=np.int32)
    y = tensorshape_util.with_rank(x, 3)
    self.assertAllEqual(x, y)

    x = np.array([2, 3, 4, 5], dtype=np.int32)
    y = tensorshape_util.with_rank_at_least(x, 3)
    self.assertAllEqual(x, y)
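The tests above rely on the contract that `with_rank` requires an exact rank while `with_rank_at_least` only enforces a minimum. A small sketch of that contract, assuming the internal `tensorshape_util` module is importable as below:

```python
import tensorflow as tf
from tensorflow_probability.python.internal import tensorshape_util

shape = tf.TensorShape([2, 3])
tensorshape_util.with_rank(shape, 2)           # OK: rank matches exactly
tensorshape_util.with_rank_at_least(shape, 1)  # OK: rank >= 1
# tensorshape_util.with_rank(shape, 3)         # would raise ValueError
```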
Example 6
  def _cache_input_depth(self, x):
    if self._input_depth is None:
      self._input_depth = tf.compat.dimension_value(
          tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
      if self._input_depth is None:
        raise NotImplementedError(
            'Rightmost dimension must be known prior to graph execution.')

      if abs(self._masked_size) >= self._input_depth:
        raise ValueError(
            'Number of masked units {} must be smaller than the event size {}.'
            .format(self._masked_size, self._input_depth))
Example 7
def interpolate_scale(grid, scale):
  """Helper which interpolates between two scales."""
  if len(scale) != 2:
    raise NotImplementedError("Currently only bimixtures are supported; "
                              "len(scale)={} is not 2.".format(len(scale)))
  deg = tf.compat.dimension_value(
      tensorshape_util.with_rank_at_least(grid.shape, 1)[-1])
  if deg is None:
    raise ValueError("Num quadrature grid points must be known prior "
                     "to graph execution.")
  with tf.name_scope("interpolate_scale"):
    return [linop_add_lib.add_operators([
        linop_scale(grid[..., k, q], s)
        for k, s in enumerate(scale)
    ])[0] for q in range(deg)]
 def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                                **distribution_kwargs):
   """Finish computation of prob on one element of the inverse image."""
   x = self._maybe_rotate_dims(x, rotate_right=True)
   prob = self.distribution.prob(x, **distribution_kwargs)
   if self._is_maybe_event_override:
     prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
   prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
   if self._is_maybe_event_override and isinstance(event_ndims, int):
     tensorshape_util.set_shape(
         prob,
         tf.broadcast_static_shape(
             tensorshape_util.with_rank_at_least(y.shape, 1)[:-event_ndims],
             self.batch_shape))
   return prob
Example 9
 def _expand_mix_distribution_probs(self):
     p = self.mixture_distribution.probs_parameter()  # [B, deg]
     deg = tf.compat.dimension_value(
         tensorshape_util.with_rank_at_least(p.shape, 1)[-1])
     if deg is None:
         deg = tf.shape(input=p)[-1]
     event_ndims = tensorshape_util.rank(self.event_shape)
     if event_ndims is None:
         event_ndims = tf.shape(input=self.event_shape_tensor())[0]
     expand_shape = tf.concat([
         self.mixture_distribution.batch_shape_tensor(),
         tf.ones([event_ndims], dtype=tf.int32),
         [deg],
     ], axis=0)
     return tf.reshape(p, shape=expand_shape)
 def _bijector_fn(x, **condition_kwargs):
     if conditioning is not None:
         x = tf.concat([conditioning, x], axis=-1)
         cond_depth = tf.compat.dimension_value(
             tensorshape_util.with_rank_at_least(
                 conditioning.shape, 1)[-1])
     else:
         cond_depth = 0
     params = shift_and_log_scale_fn(x, **condition_kwargs)
     if tf.is_tensor(params):
         shift, log_scale = tf.unstack(params, num=2, axis=-1)
     else:
         shift, log_scale = params
     shift = shift[..., cond_depth:]
     log_scale = log_scale[..., cond_depth:]
     return affine_scalar.AffineScalar(shift=shift,
                                       log_scale=log_scale)
Example 11
 def _fn(x):
     """MADE parameterized via `masked_autoregressive_default_template`."""
     # TODO(b/67594795): Better support of dynamic shape.
     input_depth = tf.compat.dimension_value(
         tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
     if input_depth is None:
         raise NotImplementedError(
             'Rightmost dimension must be known prior to graph execution.'
         )
     input_shape = (np.int32(tensorshape_util.as_list(x.shape))
                    if tensorshape_util.is_fully_defined(x.shape) else
                    tf.shape(x))
     if tensorshape_util.rank(x.shape) == 1:
         x = x[tf.newaxis, ...]
     for i, units in enumerate(hidden_layers):
         x = masked_dense(
             inputs=x,
             units=units,
             num_blocks=input_depth,
             exclusive=True if i == 0 else False,
             activation=activation,
             *args,  # pylint: disable=keyword-arg-before-vararg
             **kwargs)
     x = masked_dense(
         inputs=x,
         units=(1 if shift_only else 2) * input_depth,
         num_blocks=input_depth,
         activation=None,
         *args,  # pylint: disable=keyword-arg-before-vararg
         **kwargs)
     if shift_only:
         x = tf.reshape(x, shape=input_shape)
         return x, None
     x = tf.reshape(x, shape=tf.concat([input_shape, [2]], axis=0))
     shift, log_scale = tf.unstack(x, num=2, axis=-1)
     which_clip = (tf.clip_by_value if log_scale_clip_gradient else
                   clip_by_value_preserve_gradient)
     log_scale = which_clip(log_scale, log_scale_min_clip,
                            log_scale_max_clip)
     return shift, log_scale
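A hedged usage sketch (older TF1-style TFP API): a template like `_fn` above is what `masked_autoregressive_default_template` returns, and it is typically passed to `MaskedAutoregressiveFlow` as the shift-and-log-scale function:

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

maf = tfb.MaskedAutoregressiveFlow(
    shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
        hidden_layers=[32, 32]))
```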
    def _forward(self, x, **kwargs):
        static_event_size = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(
                x.shape, self._event_ndims)[-self._event_ndims:])

        if self._unroll_loop:
            if not static_event_size:
                raise ValueError(
                    'The final {} dimensions of `x` must be known at graph '
                    'construction time if `unroll_loop=True`. `x.shape: {!r}`'.
                    format(self._event_ndims, x.shape))
            y = tf.zeros_like(x, name='y0')

            for _ in range(static_event_size):
                y = self._bijector_fn(y, **kwargs).forward(x)
            return y

        event_size = tf.reduce_prod(tf.shape(x)[-self._event_ndims:])
        y0 = tf.zeros_like(x, name='y0')
        # Call the bijector function once to ensure variable creation.
        if not tf.executing_eagerly():
            _ = self._bijector_fn(y0, **kwargs).forward(y0)

        def _loop_body(y0):
            """While-loop body for autoregression calculation."""
            # Set caching device to avoid re-getting the tf.Variable for every while
            # loop iteration.
            with tf1.variable_scope(tf1.get_variable_scope()) as vs:
                if vs.caching_device is None and not tf.executing_eagerly():
                    vs.set_caching_device(lambda op: op.device)
                bijector = self._bijector_fn(y0, **kwargs)
            y = bijector.forward(x)
            return (y, )

        (y, ) = tf.while_loop(cond=lambda _: True,
                              body=_loop_body,
                              loop_vars=(y0, ),
                              maximum_iterations=event_size)
        return y
Example 13
    def __init__(self,
                 mean_direction,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VonMisesFisher',
                 check_dim=True):

        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([mean_direction, concentration],
                                            tf.float32)
            self._mean_direction = tensor_util.convert_nonref_to_tensor(
                mean_direction, name='mean_direction', dtype=dtype)
            self._concentration = tensor_util.convert_nonref_to_tensor(
                concentration, name='concentration', dtype=dtype)

            static_event_dim = tf.compat.dimension_value(
                tensorshape_util.with_rank_at_least(self._mean_direction.shape,
                                                    1)[-1])
            if check_dim:
                if static_event_dim is not None and static_event_dim > 5:
                    raise ValueError(
                        'von Mises-Fisher ndims > 5 is not currently '
                        'supported')

            # mean_direction is always reparameterized.
            # concentration is only for event_dim==3, via an inversion sampler.
            reparameterization_type = (reparameterization.FULLY_REPARAMETERIZED
                                       if static_event_dim == 3 else
                                       reparameterization.NOT_REPARAMETERIZED)
            super(tfd.VonMisesFisher, self).__init__(
                dtype=self._concentration.dtype,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                reparameterization_type=reparameterization_type,
                parameters=parameters,
                name=name)
Example 14
def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args."""
  with tf.name_scope("check_" + name):
    assertions = []
    if tensorshape_util.rank(param.shape) is not None:
      if tensorshape_util.rank(param.shape) == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, tensorshape_util.rank(param.shape)))
    elif validate_args:
      assertions.append(
          assert_util.assert_rank_at_least(
              param,
              1,
              message=("Mixing params must be a (batch of) vector; "
                       "{}.rank is not at least one.".format(name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    if tensorshape_util.with_rank_at_least(param.shape, 1)[-1] is not None:
      if tf.compat.dimension_value(param.shape[-1]) != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name,
                                      tf.compat.dimension_value(
                                          param.shape[-1])))
    elif validate_args:
      assertions.append(
          assert_util.assert_equal(
              tf.shape(input=param)[-1],
              1,
              message=("Currently only bimixtures are supported; "
                       "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return distribution_util.with_dependencies(assertions, param)
    return param
Example 15
def fill_triangular_inverse(x, upper=False, name=None):
    """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

    with tf.name_scope(name or 'fill_triangular_inverse'):
        x = tf.convert_to_tensor(x, name='x')
        n = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(x.shape, 2)[-1])
        if n is not None:
            n = np.int32(n)
            m = np.int32((n * (n + 1)) // 2)
            static_final_shape = tensorshape_util.concatenate(
                x.shape[:-2], [m])
        else:
            n = tf.shape(x)[-1]
            m = (n * (n + 1)) // 2
            static_final_shape = tensorshape_util.concatenate(
                tensorshape_util.with_rank_at_least(x.shape, 2)[:-2], [None])
        ndims = prefer_static.rank(x)
        if upper:
            initial_elements = x[..., 0, :]
            triangular_portion = x[..., 1:, :]
        else:
            initial_elements = tf.reverse(x[..., -1, :], axis=[ndims - 2])
            triangular_portion = x[..., :-1, :]
        rotated_triangular_portion = tf.reverse(tf.reverse(triangular_portion,
                                                           axis=[ndims - 1]),
                                                axis=[ndims - 2])
        consolidated_matrix = triangular_portion + rotated_triangular_portion
        end_sequence = tf.reshape(
            consolidated_matrix,
            tf.concat([tf.shape(x)[:-2], [n * (n - 1)]], axis=0))
        y = tf.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
        tensorshape_util.set_shape(y, static_final_shape)
        return y
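A hedged round-trip sketch using the public wrappers (assumed here to be exposed as `tfp.math.fill_triangular` and `tfp.math.fill_triangular_inverse`), matching the docstring example above:

```python
import numpy as np
import tensorflow_probability as tfp

x = np.array([1., 2., 3., 4., 5., 6.])
m = tfp.math.fill_triangular(x)       # [[4, 0, 0], [6, 5, 0], [3, 2, 1]]
tfp.math.fill_triangular_inverse(m)   # [1, 2, 3, 4, 5, 6]
```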
Example 16
    def _sample_n(self, n, seed=None):
        stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[],
        # in which case ids is a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                input_tensor=self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(input_tensor=tf.shape(
                input=self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
 def _event_shape(self):
     param = self._logits if self._logits is not None else self._probs
     return tensorshape_util.with_rank_at_least(param.shape, 1)[-1:]
Example 18
 def _batch_shape(self):
   return tf.broadcast_static_shape(
       tensorshape_util.with_rank_at_least(self.mean_direction.shape, 1)[:-1],
       self.concentration.shape)
Example 19
 def _event_shape(self, scores=None):
   scores = self._scores if scores is None else scores
   return tensorshape_util.with_rank_at_least(scores.shape, 1)[-1:]
Example 20
 def _event_shape(self):
   return tensorshape_util.with_rank_at_least(self.logits.shape, 1)[-1:]
Example 21
def masked_dense(
        inputs,
        units,
        num_blocks=None,
        exclusive=False,
        kernel_initializer=None,
        reuse=None,
        name=None,
        *args,  # pylint: disable=keyword-arg-before-vararg
        **kwargs):
    """A autoregressively masked dense layer. Analogous to `tf.layers.dense`.

  See [Germain et al. (2015)][1] for detailed explanation.

  Arguments:
    inputs: Tensor input.
    units: Python `int` scalar representing the dimensionality of the output
      space.
    num_blocks: Python `int` scalar representing the number of blocks for the
      MADE masks.
    exclusive: Python `bool` scalar representing whether to zero the diagonal of
      the mask, used for the first layer of a MADE.
    kernel_initializer: Initializer function for the weight matrix.
      If `None` (default), weights are initialized using
      `tf1.glorot_normal_initializer`.
    reuse: Python `bool` scalar representing whether to reuse the weights of a
      previous layer by the same name.
    name: Python `str` used to describe ops managed by this function.
    *args: `tf.layers.dense` arguments.
    **kwargs: `tf.layers.dense` keyword arguments.

  Returns:
    Output tensor.

  Raises:
    NotImplementedError: if rightmost dimension of `inputs` is unknown prior to
      graph execution.

  #### References

  [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE:
       Masked Autoencoder for Distribution Estimation. In _International
       Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509
  """
    # TODO(b/67594795): Better support of dynamic shape.
    input_depth = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(inputs.shape, 1)[-1])
    if input_depth is None:
        raise NotImplementedError(
            'Rightmost dimension must be known prior to graph execution.')

    mask = _gen_mask(num_blocks, input_depth, units,
                     MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T

    if kernel_initializer is None:
        kernel_initializer = tf1.glorot_normal_initializer()

    def masked_initializer(shape, dtype=None, partition_info=None):
        return mask * kernel_initializer(shape, dtype, partition_info)

    with tf.name_scope(name or 'masked_dense'):
        layer = tf1.layers.Dense(
            units,
            kernel_initializer=masked_initializer,
            kernel_constraint=lambda x: mask * x,
            name=name,
            dtype=dtype_util.base_dtype(inputs.dtype),
            _scope=name,
            _reuse=reuse,
            *args,  # pylint: disable=keyword-arg-before-vararg
            **kwargs)
        return layer.apply(inputs)
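The mask produced by `_gen_mask` enforces the MADE autoregressive property: output block `j` may only see input blocks `i <= j` (or `i < j` when `exclusive=True`, used for the first layer). A toy sketch of that block mask, not TFP's actual `_gen_mask`:

```python
import numpy as np

def toy_block_mask(num_blocks, exclusive=False):
    # Entry [j, i] is 1 iff output block j may depend on input block i.
    k = -1 if exclusive else 0
    return np.tril(np.ones([num_blocks, num_blocks], dtype=np.float32), k=k)

toy_block_mask(3)                  # [[1,0,0],[1,1,0],[1,1,1]]
toy_block_mask(3, exclusive=True)  # [[0,0,0],[1,0,0],[1,1,0]]
```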
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        mdbs = self.mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[:-1]
        if (tensorshape_util.is_fully_defined(mdbs)
                and tensorshape_util.is_fully_defined(cdbs)):
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    '`mixture_distribution.batch_shape` (`{}`) is not '
                    'compatible with `components_distribution.batch_shape` '
                    '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif self.validate_args:
            if not tensorshape_util.is_fully_defined(mdbs):
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                mdbs = tf.shape(mixture_dist_param)[:-1]
            if not tensorshape_util.is_fully_defined(cdbs):
                if component_bst is None:
                    component_bst = self.components_distribution.batch_shape_tensor(
                    )
                cdbs = component_bst[:-1]
            assertions += [
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                    cdbs,
                    message=(
                        '`mixture_distribution.batch_shape` is not '
                        'compatible with `components_distribution.batch_shape`'
                    ))
            ]

        return assertions
 def _batch_shape(self):
     return tensorshape_util.with_rank_at_least(
         self.components_distribution.batch_shape, 1)[:-1]
Example 24
 def _batch_shape(self):
     return tensorshape_util.with_rank_at_least(self._mean_val.shape,
                                                1)[:-1]
Example 25
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=tf.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Categorical"):
        """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                multidimensional=True,
                name=name)

            if validate_args:
                self._logits = distribution_util.embed_check_categorical_event_shape(
                    self._logits)

            logits_shape_static = tensorshape_util.with_rank_at_least(
                self._logits.shape, 1)
            if tensorshape_util.rank(logits_shape_static) is not None:
                self._batch_rank = tf.convert_to_tensor(
                    value=tensorshape_util.rank(logits_shape_static) - 1,
                    dtype=tf.int32,
                    name="batch_rank")
            else:
                with tf.name_scope("batch_rank"):
                    self._batch_rank = tf.rank(self._logits) - 1

            logits_shape = tf.shape(input=self._logits, name="logits_shape")
            num_categories = tf.compat.dimension_value(logits_shape_static[-1])
            if num_categories is not None:
                self._num_categories = tf.convert_to_tensor(
                    value=num_categories,
                    dtype=tf.int32,
                    name="num_categories")
            else:
                with tf.name_scope("num_categories"):
                    self._num_categories = logits_shape[self._batch_rank]

            if logits_shape_static[:-1].is_fully_defined():
                self._batch_shape_val = tf.constant(
                    logits_shape_static[:-1].as_list(),
                    dtype=tf.int32,
                    name="batch_shape")
            else:
                with tf.name_scope("batch_shape"):
                    self._batch_shape_val = logits_shape[:-1]
        super(Categorical, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._probs],
            name=name)
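A hedged usage sketch of the resulting distribution (standard `tfp.distributions.Categorical` API assumed):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# A batch of two 3-class Categorical distributions.
cat = tfd.Categorical(logits=[[1., 0., -1.], [0., 0., 0.]])
cat.sample(4)     # shape [4, 2], int32 class indices
cat.prob([0, 2])  # probability of class 0 (first member) and class 2 (second)
```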
Example 26
def fill_triangular(x, upper=False, name=None):
    """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to create an upper triangular matrix by concatenating `x`
  and a tail of itself, then reshaping.

  Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
  from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
  contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
  (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
  the first (`n = 5`) elements removed and reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
  #            12, 11, 10, 9, 8, 7, 6])

  # (We add one to the arange result to disambiguate the zeros below the
  # diagonal of our upper-triangular matrix from the first entry in `x`.)

  # Now, reshape this to lay it out as a matrix:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  From this example we see that the resulting matrix is upper-triangular, and
  contains all the entries of x, as desired. The rest is details:
  - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
    `n / 2` rows and half of an additional row), but the whole scheme still
    works.
  - If we want a lower triangular matrix instead of an upper triangular,
    we remove the first `n` elements from `x` rather than from the reversed
    `x`.

  For additional comparisons, a pure numpy version of this function can be found
  in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

    with tf.name_scope(name or 'fill_triangular'):
        x = tf.convert_to_tensor(x, name='x')
        m = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
        if m is not None:
            # Formula derived by solving for n: m = n(n+1)/2.
            m = np.int32(m)
            n = np.sqrt(0.25 + 2. * m) - 0.5
            if n != np.floor(n):
                raise ValueError(
                    'Input right-most shape ({}) does not '
                    'correspond to a triangular matrix.'.format(m))
            n = np.int32(n)
            static_final_shape = tensorshape_util.concatenate(
                x.shape[:-1], [n, n])
        else:
            m = tf.shape(x)[-1]
            # For derivation, see above. Casting automatically lops off the 0.5, so we
            # omit it.  We don't validate n is an integer because this has
            # graph-execution cost; an error will be thrown from the reshape, below.
            n = tf.cast(tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)),
                        dtype=tf.int32)
            static_final_shape = tensorshape_util.concatenate(
                tensorshape_util.with_rank_at_least(x.shape, 1)[:-1],
                [None, None])

        # Try it out in numpy:
        #  n = 3
        #  x = np.arange(n * (n + 1) / 2)
        #  m = x.shape[0]
        #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
        #  x_tail = x[(m - (n**2 - m)):]
        #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
        #  # ==> array([[3, 4, 5],
        #               [5, 4, 3],
        #               [2, 1, 0]])
        #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
        #  # ==> array([[0, 1, 2],
        #               [3, 4, 5],
        #               [5, 4, 3]])
        #
        # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
        # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
        # Furthermore observe that:
        #   m - (n**2 - m)
        #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
        #   = 2 (n**2 / 2 + n / 2) - n**2
        #   = n**2 + n - n**2
        #   = n
        ndims = prefer_static.rank(x)
        if upper:
            x_list = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
        else:
            x_list = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
        new_shape = (tensorshape_util.as_list(static_final_shape)
                     if tensorshape_util.is_fully_defined(static_final_shape)
                     else tf.concat([tf.shape(x)[:-1], [n, n]], axis=0))
        x = tf.reshape(tf.concat(x_list, axis=-1), new_shape)
        x = tf.linalg.band_part(x,
                                num_lower=(0 if upper else -1),
                                num_upper=(-1 if upper else 0))
        tensorshape_util.set_shape(x, static_final_shape)
        return x
Example 27
  def __init__(
      self,
      temperature,
      logits=None,
      probs=None,
      validate_args=False,
      allow_nan_stats=True,
      name='ExpRelaxedOneHotCategorical'):
    """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of ExpRelaxedCategorical distributions. The temperature should
        be positive.
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of ExpRelaxedCategorical distributions. The first
        `N - 1` dimensions index into a batch of independent distributions and
        the last dimension represents a vector of logits for each class. Only
        one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of ExpRelaxedCategorical distributions. The first
        `N - 1` dimensions index into a batch of independent distributions and
        the last dimension represents a vector of probabilities for each
        class. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:

      dtype = dtype_util.common_dtype([logits, probs, temperature], tf.float32)
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          name=name,
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          dtype=dtype)

      with tf.control_dependencies(
          [assert_util.assert_positive(temperature)] if validate_args else []):
        self._temperature = tf.convert_to_tensor(
            temperature, name='temperature', dtype=dtype)
        self._temperature_2d = tf.reshape(
            self._temperature, [-1, 1], name='temperature_2d')

      logits_shape_static = tensorshape_util.with_rank_at_least(
          self._logits.shape, 1)
      if tensorshape_util.rank(logits_shape_static) is not None:
        self._batch_rank = tf.convert_to_tensor(
            tensorshape_util.rank(logits_shape_static) - 1,
            dtype=tf.int32,
            name='batch_rank')
      else:
        with tf.name_scope('batch_rank'):
          self._batch_rank = tf.rank(self._logits) - 1

      with tf.name_scope('event_size'):
        self._event_size = tf.shape(self._logits)[-1]

    super(ExpRelaxedOneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs, self._temperature],
        name=name)
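A hedged usage sketch: samples of `ExpRelaxedOneHotCategorical` live in the log domain of the simplex, so exponentiating a sample gives a relaxed one-hot vector that sums to one (standard TFP API assumed):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

dist = tfd.ExpRelaxedOneHotCategorical(temperature=0.5, logits=[1., 0., -1.])
log_sample = dist.sample()
relaxed_one_hot = tf.exp(log_sample)   # entries in (0, 1), summing to ~1.0
```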
Example 28
 def _batch_shape(self):
     return tensorshape_util.with_rank_at_least(
         tf.broadcast_static_shape(
             tf.TensorShape(self._total_count.shape).concatenate([1]),
             tf.TensorShape(self._concentration.shape)), 1)[:-1]
Example 29
 def _event_shape(self):
     return tensorshape_util.with_rank_at_least(self._mean_val.shape,
                                                1)[-1:]
Example 30
 def _event_shape(self):
     # Event shape depends only on concentration, not total_count.
     return tensorshape_util.with_rank_at_least(self.concentration.shape,
                                                1)[-1:]
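These `_batch_shape`/`_event_shape` helpers all follow the same convention: for a parameter of shape `[B1, ..., Bn, K]`, the trailing dimension is the event shape and the leading dimensions form the batch shape. A small sketch, assuming the internal `tensorshape_util` import below:

```python
import tensorflow as tf
from tensorflow_probability.python.internal import tensorshape_util

shape = tensorshape_util.with_rank_at_least(tf.TensorShape([4, 7, 3]), 1)
shape[:-1]  # TensorShape([4, 7])  -> batch shape
shape[-1:]  # TensorShape([3])     -> event shape
```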