Esempio n. 1
0
  def testMVNConjugateLinearUpdateSupportsBatchShape(self):
    """`mvn_conjugate_linear_update` agrees with the reference helper when the
    likelihood parameters carry a `[3, 1]` batch shape."""
    seed_stream = test_util.test_seed_stream()
    num_latents, num_outputs = 2, 4
    batch_shape = [3, 1]
    # Number of free entries of a `num_outputs x num_outputs` triangular matrix.
    tril_size = num_outputs * (num_outputs + 1) // 2

    prior_mean = tf.ones([num_latents])
    prior_scale = 5. * tf.eye(num_latents)

    # The seed draws below happen in a fixed order: scale, transform, latent.
    unconstrained_tril = tf.random.normal(
        shape=batch_shape + [tril_size], seed=seed_stream())
    likelihood_scale = tf.linalg.LinearOperatorLowerTriangular(
        tfb.FillScaleTriL().forward(unconstrained_tril))
    linear_transformation = 5. * tf.random.normal(
        batch_shape + [num_outputs, num_latents], seed=seed_stream())
    true_latent = tf.random.normal(
        batch_shape + [num_latents], seed=seed_stream())
    observation = tf.linalg.matvec(linear_transformation, true_latent)

    # Arguments passed unchanged to both the op under test and the helper.
    shared_kwargs = dict(
        prior_mean=prior_mean,
        prior_scale=prior_scale,
        linear_transformation=linear_transformation,
        observation=observation)

    posterior_mean, posterior_prec = tfd.mvn_conjugate_linear_update(
        likelihood_scale=likelihood_scale, **shared_kwargs)

    # The helper works on dense matrices rather than LinearOperators.
    self._mvn_linear_update_test_helper(
        likelihood_scale=likelihood_scale.to_dense(),
        candidate_posterior_mean=posterior_mean,
        candidate_posterior_prec=posterior_prec.to_dense(),
        **shared_kwargs)
Esempio n. 2
0
    def testComputesCorrectValues(self):
        """Forward and inverse of `FillScaleTriL(Exp, diag_shift)` on a 2x2 case."""
        diag_shift = 1.61803398875
        unconstrained = np.array([-1, .5, 2], dtype=np.float32)
        # Diagonal entries are exp(x) + diag_shift; the off-diagonal entry is
        # copied through unchanged.
        expected_tril = np.array(
            [[np.exp(2) + diag_shift, 0.],
             [.5, np.exp(-1) + diag_shift]],
            dtype=np.float32)

        bijector = tfb.FillScaleTriL(diag_bijector=tfb.Exp(),
                                     diag_shift=diag_shift)

        forward_value = self.evaluate(bijector.forward(unconstrained))
        self.assertAllClose(expected_tril, forward_value, rtol=1e-4)

        inverse_value = self.evaluate(bijector.inverse(expected_tril))
        self.assertAllClose(unconstrained, inverse_value, rtol=1e-4)
Esempio n. 3
0
    def testInvertible(self):
        """Round-trips random inputs and checks FLDJ == -ILDJ."""
        # A trailing dimension of 6 (= 3 * 4 / 2) fills one 3x3 triangular
        # matrix per batch member.
        batch_shape = [2, 1]
        unconstrained = np.random.randn(*batch_shape, 6).astype(np.float32)
        bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
                                     diag_shift=3.14159)

        tril = self.evaluate(bijector.forward(unconstrained))
        self.assertAllEqual(tril.shape, batch_shape + [3, 3])

        recovered = self.evaluate(bijector.inverse(tril))
        self.assertAllClose(unconstrained, recovered, rtol=1e-4)

        forward_ldj = self.evaluate(
            bijector.forward_log_det_jacobian(unconstrained, event_ndims=1))
        inverse_ldj = self.evaluate(
            bijector.inverse_log_det_jacobian(tril, event_ndims=2))
        self.assertAllClose(forward_ldj, -inverse_ldj, rtol=1e-4)
Esempio n. 4
0
 def new(params,
         event_shape,
         covariance,
         loc_activation=None,
         scale_activation=None,
         validate_args=False,
         name=None):
   r"""Create the distribution instance from a `params` vector.

   The first `event_size = prod(event_shape)` entries along the last axis of
   `params` are used for `loc`; the remaining entries parameterize the scale.

   Args:
     params: Tensor-like parameters (converted with `tf.convert_to_tensor`).
     event_shape: event shape; only the product of its entries is used here.
     covariance: 'tril', 'diag' or 'full' (case-insensitive; surrounding
       whitespace is ignored).
     loc_activation: optional callable applied to the `loc` slice.
     scale_activation: for 'tril', passed to `tfb.FillScaleTriL` as its
       `diag_bijector` (`None` keeps that bijector's default); for 'diag', a
       callable applied to the scale slice, defaulting to `tf.nn.softplus`.
     validate_args: forwarded to the bijector and to the distribution.
     name: distribution name; defaults to `MultivariateNormal<Covariance>`,
       e.g. 'MultivariateNormalTril'.

   Returns:
     A `tfd.MultivariateNormalTriL` ('tril') or `tfd.MultivariateNormalDiag`
     ('diag').

   Raises:
     ValueError: if `covariance` is not one of the supported values.
     NotImplementedError: if `covariance` is 'full' (deprecated).
   """
   covariance = str(covariance).lower().strip()
   event_size = tf.reduce_prod(event_shape)
   # Explicit check instead of `assert`: asserts are stripped under
   # `python -O`, which would let an unsupported value fall through every
   # branch below and silently return None.
   if covariance not in ('full', 'tril', 'diag'):
     raise ValueError(f"No support for given covariance: '{covariance}'")
   if name is None:
     name = f"MultivariateNormal{covariance.capitalize()}"
   # parameters
   params = tf.convert_to_tensor(value=params, name='params')
   loc = params[..., :event_size]
   if loc_activation is not None:
     loc = loc_activation(loc)
   scale = params[..., event_size:]
   ### the distribution
   if covariance == 'tril':
     scale_tril = tfb.FillScaleTriL(
         diag_bijector=scale_activation,
         # `as_numpy_dtype` already is the numpy type (no call needed); the
         # shift is built in the same precision as `params`.
         diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype),
         validate_args=validate_args,
     )
     return tfd.MultivariateNormalTriL(loc=loc,
                                       scale_tril=scale_tril(scale),
                                       validate_args=validate_args,
                                       name=name)
   if covariance == 'diag':
     # Default to softplus so the diagonal scale is positive; the raw
     # parameter slice may take any sign.
     if scale_activation is None:
       scale_activation = tf.nn.softplus
     return tfd.MultivariateNormalDiag(loc=loc,
                                       scale_diag=scale_activation(scale),
                                       validate_args=validate_args,
                                       name=name)
   # Only 'full' remains after the validation above.
   raise NotImplementedError(
       'MVN full covariance is deprecated, '
       'use `scale_tril=tf.linalg.cholesky(covariance_matrix)` instead')
Esempio n. 5
0
 def testJacobian(self):
     """Compares the analytic FLDJ of `CholeskyToInvCholesky` with a numerical one."""
     # Maps a Cholesky factor back to its unconstrained vector by inverting
     # FillScaleTriL (Exp on the diagonal, no shift).
     cholesky_to_vector = tfb.Invert(
         tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None))
     bijector = tfb.CholeskyToInvCholesky()

     cholesky_factors = (
         np.array([[2.]], dtype=np.float64),
         np.array([[2., 0.],
                   [3., 4.]], dtype=np.float64),
         np.array([[2., 0., 0.],
                   [3., 4., 0.],
                   [5., 6., 7.]], dtype=np.float64),
     )
     for factor in cholesky_factors:
         analytic = bijector.forward_log_det_jacobian(factor, event_ndims=2)
         numerical = self._get_fldj_numerical(
             bijector,
             factor,
             event_ndims=2,
             input_to_vector=cholesky_to_vector,
             output_to_vector=cholesky_to_vector)
         analytic_value, numerical_value = self.evaluate([analytic, numerical])
         self.assertAllClose(analytic_value, numerical_value, rtol=1e-2)
Esempio n. 6
0
 def new(params,
         event_shape,
         covariance,
         loc_activation=tf.identity,
         scale_activation=softplus1,
         validate_args=False,
         name=None):
   r"""Create the distribution instance from a `params` vector.

   `loc_activation` is applied to the first `event_size = prod(event_shape)`
   entries along the last axis of `params`; `scale_activation` is applied to
   the remaining entries.

   Args:
     params: Tensor-like parameters (converted with `tf.convert_to_tensor`).
     event_shape: event shape; only the product of its entries is used here.
     covariance: 'tril', 'diag' or 'full' (case-insensitive; surrounding
       whitespace is ignored). For 'full', the trailing axis of the scale
       slice must hold `event_size**2` entries; leading (batch) axes are
       preserved.
     loc_activation: callable applied to the `loc` slice.
     scale_activation: callable applied to the scale slice.
     validate_args: forwarded to the bijector and to the distribution.
     name: distribution name; defaults to `MultivariateNormal<Covariance>`,
       e.g. 'MultivariateNormalTril'.

   Returns:
     A `tfd.MultivariateNormalTriL`, `tfd.MultivariateNormalDiag` or
     `tfd.MultivariateNormalFullCovariance`, depending on `covariance`.

   Raises:
     ValueError: if `covariance` is not one of the supported values.
   """
   covariance = str(covariance).lower().strip()
   event_size = tf.reduce_prod(event_shape)
   # Explicit check instead of `assert`, which is stripped under `python -O`
   # and would let an unsupported value silently return None.
   if covariance not in ('full', 'tril', 'diag'):
     raise ValueError("No support for given covariance: '%s'" % covariance)
   if name is None:
     name = "MultivariateNormal%s" % covariance.capitalize()
   # parameters
   params = tf.convert_to_tensor(value=params, name='params')
   loc = loc_activation(params[..., :event_size])
   scale = scale_activation(params[..., event_size:])
   ### the distribution
   if covariance == 'tril':
     # NOTE(review): `scale_activation` has already been applied to every
     # entry, including the off-diagonal ones, before FillScaleTriL applies
     # its default `diag_bijector` -- confirm this is intended.
     scale_tril = tfb.FillScaleTriL(
         diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype),
         validate_args=validate_args,
     )
     return tfd.MultivariateNormalTriL(loc=loc,
                                       scale_tril=scale_tril(scale),
                                       validate_args=validate_args,
                                       name=name)
   if covariance == 'diag':
     return tfd.MultivariateNormalDiag(loc=loc,
                                       scale_diag=scale,
                                       validate_args=validate_args,
                                       name=name)
   # covariance == 'full': reshape only the trailing axis into an
   # `event_size x event_size` matrix so that leading batch dimensions of
   # `scale` are kept (a fixed `(event_size, event_size)` shape drops them).
   matrix_dim = tf.cast(event_size, tf.int32)
   covariance_shape = tf.concat(
       [tf.shape(scale)[:-1], tf.stack([matrix_dim, matrix_dim])], axis=0)
   return tfd.MultivariateNormalFullCovariance(
       loc=loc,
       covariance_matrix=tf.reshape(scale, covariance_shape),
       validate_args=validate_args,
       name=name)
Esempio n. 7
0
# Training-loop configuration.
batch_size, max_iter = 32, 50000
# Encoder and decoder share one layout (three hidden layers of 256 units,
# flattened inputs) and differ only in name.
encoder, decoder = (
    vi.NetConf([256] * 3, flatten_inputs=True, name=role)
    for role in ('Encoder', 'Decoder'))
# Size of the encoded (latent) representation.
encoded_size = 16
posteriors_info = [
    ('gaussian', 'mvndiag', 'mvntril'),
    (
        D.Sample(D.Normal(loc=0., scale=1.),
                 sample_shape=encoded_size,
                 name='independent'),
        D.MultivariateNormalDiag(loc=tf.zeros(encoded_size),
                                 scale_diag=tf.ones(encoded_size),
                                 name='mvndiag'),
        D.MultivariateNormalTriL(loc=tf.zeros(encoded_size),
                                 scale_tril=bj.FillScaleTriL()(tf.ones(
                                     encoded_size * (encoded_size + 1) // 2)),
                                 name='mvntril'),
        D.MixtureSameFamily(
            components_distribution=D.MultivariateNormalDiag(
                loc=tf.zeros([10, encoded_size]),
                scale_diag=tf.ones([10, encoded_size])),
            mixture_distribution=D.Categorical(logits=tf.fill([10], 1.0 / 10)),
            name='gmm10'),
        D.MixtureSameFamily(components_distribution=D.MultivariateNormalDiag(
            loc=tf.zeros([100, encoded_size]),
            scale_diag=tf.ones([100, encoded_size])),
                            mixture_distribution=D.Categorical(
                                logits=tf.fill([100], 1.0 / 100)),
                            name='gmm100'),
    ),
    ('identity', 'relu', 'softplus', 'softplus1'),
Esempio n. 8
0
def build_trainable_highway_flow(width,
                                 residual_fraction_initial_value=0.5,
                                 activation_fn=None,
                                 gate_first_n=None,
                                 seed=None,
                                 validate_args=False):
    """Builds a HighwayFlow parameterized by trainable variables.

    The variables are transformed to enforce the following parameter
    constraints:

    - `residual_fraction` is bounded between 0 and 1 (via a Sigmoid).
    - `upper_diagonal_weights_matrix` is a randomly initialized
      lower-triangular matrix (built with `FillScaleTriL`) with positive
      diagonal, of size `width x width`.
    - `lower_diagonal_weights_matrix` is a randomly initialized
      lower-triangular matrix with ones on the diagonal, of size
      `width x width`.
    - `bias` is a randomly initialized vector of size `width`.

    Random initial values are drawn from a normal distribution with mean 0
    and standard deviation 0.01, in the dtype of
    `residual_fraction_initial_value` (float32 unless that argument carries
    another floating dtype), so they match the dtype of the variables.

    Args:
      width: Input dimension of the bijector.
      residual_fraction_initial_value: Initial value for gating parameter,
        must be between 0 and 1.
      activation_fn: Callable invertible activation function
        (e.g., `tf.nn.softplus`), or `None`.
      gate_first_n: Decides which part of the input should be gated (useful
        for example when using auxiliary variables).
      seed: Seed for random initialization of the weights.
      validate_args: Python `bool`. Whether to validate input with runtime
        assertions.
        Default value: `False`.

    Returns:
      trainable_highway_flow: The initialized bijector.
    """

    residual_fraction_initial_value = tf.convert_to_tensor(
        residual_fraction_initial_value,
        dtype_hint=tf.float32,
        name='residual_fraction_initial_value')
    dtype = residual_fraction_initial_value.dtype

    bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)

    # Chain applies right-to-left: fill a lower triangle, pad one row above
    # and one column to the right (giving a strictly lower-triangular
    # `width x width` matrix), then add 1 to the diagonal. The shift is
    # created in `dtype` so the bijector is not pinned to float32.
    lower_bijector = tfb.Chain([
        tfb.TransformDiagonal(
            diag_bijector=tfb.Shift(tf.constant(1., dtype=dtype))),
        tfb.Pad(paddings=[(1, 0), (0, 1)]),
        tfb.FillTriangular()
    ])
    # Samples are drawn in `dtype`: tf.Variable / TransformedVariable do not
    # cast, so float32 samples would be rejected for a float64 `dtype`.
    unconstrained_lower_initial_values = samplers.normal(
        shape=lower_bijector.inverse_event_shape([width, width]),
        mean=0.,
        stddev=.01,
        dtype=dtype,
        seed=lower_seed)
    upper_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(),
                                       diag_shift=None)
    unconstrained_upper_initial_values = samplers.normal(
        shape=upper_bijector.inverse_event_shape([width, width]),
        mean=0.,
        stddev=.01,
        dtype=dtype,
        seed=upper_seed)

    residual_fraction = util.TransformedVariable(
        initial_value=residual_fraction_initial_value,
        bijector=tfb.Sigmoid(),
        dtype=dtype)
    bias = tf.Variable(
        samplers.normal((width,),
                        mean=0.,
                        stddev=0.01,
                        dtype=dtype,
                        seed=bias_seed),
        dtype=dtype)
    upper_diagonal_weights_matrix = util.TransformedVariable(
        initial_value=upper_bijector.forward(
            unconstrained_upper_initial_values),
        bijector=upper_bijector,
        dtype=dtype)
    lower_diagonal_weights_matrix = util.TransformedVariable(
        initial_value=lower_bijector.forward(
            unconstrained_lower_initial_values),
        bijector=lower_bijector,
        dtype=dtype)

    return HighwayFlow(
        residual_fraction=residual_fraction,
        activation_fn=activation_fn,
        bias=bias,
        upper_diagonal_weights_matrix=upper_diagonal_weights_matrix,
        lower_diagonal_weights_matrix=lower_diagonal_weights_matrix,
        gate_first_n=gate_first_n,
        validate_args=validate_args)