import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import Chain
from tensorflow_probability.python.bijectors import Shift
from tensorflow_probability.python.bijectors import Softplus
from tensorflow_probability.python.distributions import Independent
from tensorflow_probability.python.distributions import JointDistributionSequential
from tensorflow_probability.python.distributions import Normal
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.util import TransformedVariable


def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None):
  """Create learnable posterior for Variational layers with kernel and bias."""
  if kernel_initializer is None:
    kernel_initializer = tf.initializers.glorot_normal()
  if bias_initializer is None:
    bias_initializer = tf.initializers.glorot_normal()
  make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
      init(shape, dtype=dtype),
      name=name + '_loc')
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.ones(shape, dtype=dtype),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_shape,
                                      kernel_initializer,
                                      'posterior_kernel'),
                         scale=make_scale(kernel_shape, 'posterior_kernel')),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name='posterior_kernel'),
      Independent(Normal(loc=make_loc(bias_shape,
                                      bias_initializer,
                                      'posterior_bias'),
                         scale=make_scale(bias_shape, 'posterior_bias')),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name='posterior_bias'),
  ])
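# Usage sketch (illustrative, not part of the module above): shapes below
# assume a dense layer with 3 inputs and 2 units. Sampling the returned
# `JointDistributionSequential` yields one reparameterized draw per component,
# and `log_prob` scores the pair jointly.
posterior = make_kernel_bias_posterior_mvn_diag(
    kernel_shape=[3, 2], bias_shape=[2])
kernel, bias = posterior.sample()              # draws of shape [3, 2] and [2]
joint_lp = posterior.log_prob([kernel, bias])  # scalar joint log-density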
from tensorflow_probability.python.experimental.nn import initializers as nn_init_lib


def make_kernel_bias_posterior_mvn_diag(
    kernel_shape,
    bias_shape,
    kernel_initializer=None,
    bias_initializer=None,
    kernel_batch_ndims=0,
    bias_batch_ndims=0,
    dtype=tf.float32,
    kernel_name='posterior_kernel',
    bias_name='posterior_bias'):
  """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    kernel_initializer: ...
      Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_batch_ndims: ...
      Default value: `0`.
    bias_batch_ndims: ...
      Default value: `0`.
    dtype: ...
      Default value: `tf.float32`.
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
  if kernel_initializer is None:
    kernel_initializer = nn_init_lib.glorot_uniform()
  if bias_initializer is None:
    bias_initializer = tf.zeros
  make_loc = lambda init_fn, shape, batch_ndims, name: tf.Variable(  # pylint: disable=g-long-lambda
      _try_call_init_fn(init_fn, shape, dtype, batch_ndims),
      name=name + '_loc')
  # Setting the initial scale to a relatively small value causes the `loc` to
  # quickly move toward a lower loss value.
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.fill(shape, value=tf.constant(1e-3, dtype=dtype)),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_initializer,
                                      kernel_shape,
                                      kernel_batch_ndims,
                                      kernel_name),
                         scale=make_scale(kernel_shape, kernel_name)),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name=kernel_name),
      Independent(Normal(loc=make_loc(bias_initializer,
                                      bias_shape,
                                      bias_batch_ndims,
                                      bias_name),
                         scale=make_scale(bias_shape, bias_name)),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name=bias_name),
  ])
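# The helper `_try_call_init_fn` is referenced above but not defined in this
# excerpt. A minimal sketch consistent with its call site: try the richer
# signature `init_fn(shape, dtype, batch_ndims)` first, then fall back to
# `init_fn(shape, dtype)` for initializers (e.g. `tf.zeros`) that do not
# accept a `batch_ndims` argument.
def _try_call_init_fn(fn, *args):
  try:
    return fn(*args)
  except TypeError:
    return fn(*args[:-1])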
import numpy as np

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import chain
from tensorflow_probability.python.bijectors import exp
from tensorflow_probability.python.bijectors import invert
from tensorflow_probability.python.bijectors import scale
from tensorflow_probability.python.bijectors import shift
from tensorflow_probability.python.internal import dtype_util


class ActivationNormalization(bijector.Bijector):
  """Bijector to implement Activation Normalization (ActNorm)."""

  def __init__(self, nchan, dtype=tf.float32, validate_args=False, name=None):
    parameters = dict(locals())
    self._initialized = tf.Variable(False, trainable=False)
    self._m = tf.Variable(tf.zeros(nchan, dtype))
    self._s = TransformedVariable(tf.ones(nchan, dtype), exp.Exp())
    # Inverting Chain([Scale, Shift]) makes forward(x) = x / s - m, the
    # inverse of x -> s * (x + m).
    self._bijector = invert.Invert(
        chain.Chain([
            scale.Scale(self._s),
            shift.Shift(self._m),
        ]))
    super(ActivationNormalization, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        parameters=parameters,
        name=name or 'ActivationNormalization')

  def _inverse(self, y, **kwargs):
    with tf.control_dependencies([self._maybe_init(y, inverse=True)]):
      return self._bijector.inverse(y, **kwargs)

  def _forward(self, x, **kwargs):
    with tf.control_dependencies([self._maybe_init(x, inverse=False)]):
      return self._bijector.forward(x, **kwargs)

  def _inverse_log_det_jacobian(self, y, **kwargs):
    with tf.control_dependencies([self._maybe_init(y, inverse=True)]):
      return self._bijector.inverse_log_det_jacobian(y, 1, **kwargs)

  def _forward_log_det_jacobian(self, x, **kwargs):
    with tf.control_dependencies([self._maybe_init(x, inverse=False)]):
      return self._bijector.forward_log_det_jacobian(x, 1, **kwargs)

  def _maybe_init(self, inputs, inverse):
    """Initialize if not already initialized."""

    def _init():
      """Build the data-dependent initialization."""
      # Reduce over every axis except the channel (last) axis.
      axis = prefer_static.range(prefer_static.rank(inputs) - 1)
      m = tf.math.reduce_mean(inputs, axis=axis)
      # Pad the std with a small epsilon so constant channels don't produce a
      # zero scale.
      s = (tf.math.reduce_std(inputs, axis=axis) +
           10. * np.finfo(dtype_util.as_numpy_dtype(inputs.dtype)).eps)
      if inverse:
        # Pick (m, s) so that inverse(inputs) = s * (inputs + m) is
        # standardized.
        s = 1 / s
        m = -m
      else:
        # Pick (m, s) so that forward(inputs) = inputs / s - m is
        # standardized.
        m = m / s
      with tf.control_dependencies([self._m.assign(m), self._s.assign(s)]):
        return self._initialized.assign(True)

    return tf.cond(self._initialized, tf.no_op, _init)
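# Usage sketch (illustrative): the first pass through the bijector triggers
# the data-dependent initialization, standardizing each channel against the
# first batch; afterwards `_m` and `_s` train as ordinary variables.
actnorm = ActivationNormalization(nchan=3)
x = tf.random.normal([64, 3])   # 64 events with 3 channels each
y = actnorm.forward(x)          # first call assigns _m, _s from batch stats
fldj = actnorm.forward_log_det_jacobian(x, event_ndims=1)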
def make_kernel_bias_posterior_mvn_diag(kernel_shape,
                                        bias_shape,
                                        dtype=tf.float32,
                                        kernel_initializer=None,
                                        bias_initializer=None,
                                        kernel_name='posterior_kernel',
                                        bias_name='posterior_bias'):
  """Create learnable posterior for Variational layers with kernel and bias.

  Args:
    kernel_shape: ...
    bias_shape: ...
    dtype: ...
      Default value: `tf.float32`.
    kernel_initializer: ...
      Default value: `None` (i.e., `tf.initializers.glorot_uniform()`).
    bias_initializer: ...
      Default value: `None` (i.e., `tf.zeros`).
    kernel_name: ...
      Default value: `"posterior_kernel"`.
    bias_name: ...
      Default value: `"posterior_bias"`.

  Returns:
    kernel_and_bias_distribution: ...
  """
  if kernel_initializer is None:
    kernel_initializer = tf.initializers.glorot_uniform()
  if bias_initializer is None:
    bias_initializer = tf.zeros
  make_loc = lambda shape, init, name: tf.Variable(  # pylint: disable=g-long-lambda
      init(shape, dtype=dtype),
      name=name + '_loc')
  make_scale = lambda shape, name: TransformedVariable(  # pylint: disable=g-long-lambda
      tf.ones(shape, dtype=dtype),
      Chain([Shift(1e-5), Softplus()]),
      name=name + '_scale')
  return JointDistributionSequential([
      Independent(Normal(loc=make_loc(kernel_shape,
                                      kernel_initializer,
                                      kernel_name),
                         scale=make_scale(kernel_shape, kernel_name)),
                  reinterpreted_batch_ndims=prefer_static.size(kernel_shape),
                  name=kernel_name),
      Independent(Normal(loc=make_loc(bias_shape,
                                      bias_initializer,
                                      bias_name),
                         scale=make_scale(bias_shape, bias_name)),
                  reinterpreted_batch_ndims=prefer_static.size(bias_shape),
                  name=bias_name),
  ])
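# Sketch of overriding the defaults (the initializer choices here are
# hypothetical): any callable accepting `init(shape, dtype=...)` works, so
# Keras-style initializers and `tf.zeros` are both valid.
posterior = make_kernel_bias_posterior_mvn_diag(
    kernel_shape=[5, 4],
    bias_shape=[4],
    kernel_initializer=tf.initializers.he_normal(),
    bias_initializer=tf.zeros)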