Esempio n. 1
0
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None):
    """Build a trainable diagonal-Normal posterior over a kernel and a bias.

    Each of the two components is a `Normal` whose `loc` is a `tf.Variable`
    and whose `scale` is a `TransformedVariable`, wrapped in `Independent`.

    Args:
      kernel_shape: Shape handed to the kernel initializer and used to size
        the kernel scale.
      bias_shape: Shape handed to the bias initializer and used to size the
        bias scale.
      dtype: dtype passed to the initializers and used for the initial scale.
        Default value: `tf.float32`.
      kernel_initializer: Callable invoked as `init(shape, dtype=dtype)` to
        produce the initial kernel `loc`.
        Default value: `None` (i.e., `tf.initializers.glorot_normal()`).
      bias_initializer: Callable invoked as `init(shape, dtype=dtype)` to
        produce the initial bias `loc`.
        Default value: `None` (i.e., `tf.initializers.glorot_normal()`).

    Returns:
      A `JointDistributionSequential` holding two components, named
      `'posterior_kernel'` and `'posterior_bias'`, in that order.
    """
    if kernel_initializer is None:
        kernel_initializer = tf.initializers.glorot_normal()
    if bias_initializer is None:
        bias_initializer = tf.initializers.glorot_normal()

    def _build_component(shape, init, name):
        # `loc` is created before `scale` so variables are instantiated in the
        # same order for every component.
        loc = tf.Variable(init(shape, dtype=dtype), name=name + '_loc')
        # The scale starts from ones and is constrained by the
        # Shift(1e-5)/Softplus chain.
        scale = TransformedVariable(
            tf.ones(shape, dtype=dtype),
            Chain([Shift(1e-5), Softplus()]),
            name=name + '_scale')
        base = Normal(loc=loc, scale=scale)
        # One reinterpreted batch dimension per entry of `shape`.
        return Independent(
            base,
            reinterpreted_batch_ndims=prefer_static.size(shape),
            name=name)

    kernel_posterior = _build_component(
        kernel_shape, kernel_initializer, 'posterior_kernel')
    bias_posterior = _build_component(
        bias_shape, bias_initializer, 'posterior_bias')
    return JointDistributionSequential([kernel_posterior, bias_posterior])
Esempio n. 2
0
def make_kernel_bias_posterior_mvn_diag(
        kernel_shape,
        bias_shape,
        kernel_initializer=None,
        bias_initializer=None,
        kernel_batch_ndims=0,
        bias_batch_ndims=0,
        dtype=tf.float32,
        kernel_name='posterior_kernel',
        bias_name='posterior_bias'):
    """Create learnable posterior for Variational layers with kernel and bias.

    Each of the two components is a `Normal` whose `loc` is a `tf.Variable`
    and whose `scale` is a `TransformedVariable`, wrapped in `Independent`.

    Args:
      kernel_shape: Shape handed to the kernel initializer and used to size
        the kernel scale.
      bias_shape: Shape handed to the bias initializer and used to size the
        bias scale.
      kernel_initializer: Callable producing the initial kernel `loc`; it is
        invoked through `_try_call_init_fn` with
        `(kernel_shape, dtype, kernel_batch_ndims)`.
        Default value: `None` (i.e., `nn_init_lib.glorot_uniform()`).
      bias_initializer: Callable producing the initial bias `loc`; it is
        invoked through `_try_call_init_fn` with
        `(bias_shape, dtype, bias_batch_ndims)`.
        Default value: `None` (i.e., `tf.zeros`).
      kernel_batch_ndims: Value forwarded to the kernel initializer call.
        Default value: `0`.
      bias_batch_ndims: Value forwarded to the bias initializer call.
        Default value: `0`.
      dtype: dtype passed to the initializers and used for the initial scale.
        Default value: `tf.float32`.
      kernel_name: Name of the kernel component; also the prefix of its
        `_loc` / `_scale` variable names.
        Default value: `"posterior_kernel"`.
      bias_name: Name of the bias component; also the prefix of its
        `_loc` / `_scale` variable names.
        Default value: `"posterior_bias"`.

    Returns:
      kernel_and_bias_distribution: A `JointDistributionSequential` holding
        the kernel component followed by the bias component.
    """
    if kernel_initializer is None:
        kernel_initializer = nn_init_lib.glorot_uniform()
    if bias_initializer is None:
        bias_initializer = tf.zeros

    def _build_component(init_fn, shape, batch_ndims, name):
        # `loc` is created before `scale` so variables are instantiated in the
        # same order for every component.
        loc = tf.Variable(
            _try_call_init_fn(init_fn, shape, dtype, batch_ndims),
            name=name + '_loc')
        # Setting the initial scale to a relatively small value causes the
        # `loc` to quickly move toward a lower loss value.
        scale = TransformedVariable(
            tf.fill(shape, value=tf.constant(1e-3, dtype=dtype)),
            Chain([Shift(1e-5), Softplus()]),
            name=name + '_scale')
        # One reinterpreted batch dimension per entry of `shape`.
        return Independent(
            Normal(loc=loc, scale=scale),
            reinterpreted_batch_ndims=prefer_static.size(shape),
            name=name)

    return JointDistributionSequential([
        _build_component(kernel_initializer, kernel_shape,
                         kernel_batch_ndims, kernel_name),
        # The bias uses its own `bias_batch_ndims`, not the kernel's value.
        _build_component(bias_initializer, bias_shape,
                         bias_batch_ndims, bias_name),
    ])
Esempio n. 3
0
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None,
                                        kernel_name='posterior_kernel',
                                        bias_name='posterior_bias'):
    """Build a trainable diagonal-Normal posterior over a kernel and a bias.

    Each of the two components is a `Normal` whose `loc` is a `tf.Variable`
    and whose `scale` is a `TransformedVariable`, wrapped in `Independent`.

    Args:
      kernel_shape: Shape handed to the kernel initializer and used to size
        the kernel scale.
      bias_shape: Shape handed to the bias initializer and used to size the
        bias scale.
      dtype: dtype passed to the initializers and used for the initial scale.
        Default value: `tf.float32`.
      kernel_initializer: Callable invoked as `init(shape, dtype=dtype)` to
        produce the initial kernel `loc`.
        Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
      bias_initializer: Callable invoked as `init(shape, dtype=dtype)` to
        produce the initial bias `loc`.
        Default value: `None` (i.e., `tf.zeros`).
      kernel_name: Name of the kernel component; also the prefix of its
        `_loc` / `_scale` variable names.
        Default value: `"posterior_kernel"`.
      bias_name: Name of the bias component; also the prefix of its
        `_loc` / `_scale` variable names.
        Default value: `"posterior_bias"`.

    Returns:
      kernel_and_bias_distribution: A `JointDistributionSequential` holding
        the kernel component followed by the bias component.
    """
    if kernel_initializer is None:
        kernel_initializer = tf.initializers.glorot_uniform()
    if bias_initializer is None:
        bias_initializer = tf.zeros

    def _build_component(shape, init, name):
        # `loc` is created before `scale` so variables are instantiated in the
        # same order for every component.
        loc = tf.Variable(init(shape, dtype=dtype), name=name + '_loc')
        # The scale starts from ones and is constrained by the
        # Shift(1e-5)/Softplus chain.
        scale = TransformedVariable(
            tf.ones(shape, dtype=dtype),
            Chain([Shift(1e-5), Softplus()]),
            name=name + '_scale')
        base = Normal(loc=loc, scale=scale)
        # One reinterpreted batch dimension per entry of `shape`.
        return Independent(
            base,
            reinterpreted_batch_ndims=prefer_static.size(shape),
            name=name)

    kernel_posterior = _build_component(
        kernel_shape, kernel_initializer, kernel_name)
    bias_posterior = _build_component(
        bias_shape, bias_initializer, bias_name)
    return JointDistributionSequential([kernel_posterior, bias_posterior])