def call(self, inputs, training=None):
  if not isinstance(self.kernel, random_variable.RandomVariable):
    return super(Conv2DVariationalDropout, self).call(inputs)
  self.call_weights()
  if training is None:
    training = tf.keras.backend.learning_phase()
  if self._convolution_op is None:
    padding = self.padding
    if self.padding == 'causal':
      padding = 'valid'
    if not isinstance(padding, (list, tuple)):
      padding = padding.upper()
    self._convolution_op = functools.partial(
        tf.nn.convolution,
        strides=self.strides,
        padding=padding,
        data_format='NHWC' if self.data_format == 'channels_last' else 'NCHW',
        dilations=self.dilation_rate)

  def dropped_inputs():
    """Forward pass with dropout."""
    # Clip magnitude of dropout rate, where we get the dropout rate alpha from
    # the additive parameterization (Molchanov et al., 2017): for weight ~
    # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
    mean = self.kernel.distribution.mean()
    log_variance = tf.math.log(self.kernel.distribution.variance())
    log_alpha = log_variance - tf.math.log(
        tf.square(mean) + tf.keras.backend.epsilon())
    log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
    log_variance = log_alpha + tf.math.log(
        tf.square(mean) + tf.keras.backend.epsilon())
    # Local reparameterization: convolve the inputs with the kernel's mean and
    # variance to get per-activation means/stddevs, then sample the output
    # Normal directly instead of sampling a kernel.
    means = self._convolution_op(inputs, mean)
    stddevs = tf.sqrt(
        self._convolution_op(tf.square(inputs), tf.exp(log_variance)) +
        tf.keras.backend.epsilon())
    if self.use_bias:
      if self.data_format == 'channels_first':
        means = tf.nn.bias_add(means, self.bias, data_format='NCHW')
      else:
        means = tf.nn.bias_add(means, self.bias, data_format='NHWC')
    outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
    if self.activation is not None:
      outputs = self.activation(outputs)
    return outputs

  # Following tf.keras.Dropout, only apply variational dropout if training
  # flag is True.
  training_value = utils.smart_constant_value(training)
  if training_value is not None:
    if training_value:
      return dropped_inputs()
    else:
      return super(Conv2DVariationalDropout, self).call(inputs)
  return tf.cond(
      pred=training,
      true_fn=dropped_inputs,
      false_fn=lambda: super(Conv2DVariationalDropout, self).call(inputs))
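
# Usage sketch (illustrative; not part of the original source). It assumes the
# class above is exposed as `ed.layers.Conv2DVariationalDropout` in Edward2 and
# accepts the usual tf.keras.layers.Conv2D constructor arguments. With
# training=True the forward pass is stochastic (pre-activations are sampled
# from the Normal built above); with training=False the layer defers to the
# parent convolutional layer's call.
import edward2 as ed
import tensorflow as tf

conv = ed.layers.Conv2DVariationalDropout(
    filters=16, kernel_size=3, activation='relu')
images = tf.random.normal([8, 32, 32, 3])
stochastic_out = conv(images, training=True)      # sampled, noisy activations
deterministic_out = conv(images, training=False)  # parent call, no activation-level noise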
def call(self, inputs, training=None):
  if not isinstance(self.kernel, random_variable.RandomVariable):
    return super(DenseVariationalDropout, self).call(inputs)
  self.call_weights()
  if training is None:
    training = tf.keras.backend.learning_phase()

  def dropped_inputs():
    """Forward pass with dropout."""
    # Clip magnitude of dropout rate, where we get the dropout rate alpha from
    # the additive parameterization (Molchanov et al., 2017): for weight ~
    # Normal(mu, sigma**2), the variance `sigma**2 = alpha * mu**2`.
    mean = self.kernel.distribution.mean()
    log_variance = tf.math.log(self.kernel.distribution.variance())
    log_alpha = log_variance - tf.math.log(
        tf.square(mean) + tf.keras.backend.epsilon())
    log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
    log_variance = log_alpha + tf.math.log(
        tf.square(mean) + tf.keras.backend.epsilon())
    # Local reparameterization: propagate the kernel's mean and variance
    # through the matmul/tensordot to get per-activation means/stddevs, then
    # sample the pre-activations directly instead of sampling a kernel.
    if inputs.shape.ndims <= 2:
      means = tf.matmul(inputs, mean)
      stddevs = tf.sqrt(
          tf.matmul(tf.square(inputs), tf.exp(log_variance)) +
          tf.keras.backend.epsilon())
    else:
      # Higher-rank inputs contract along the last axis, matching
      # tf.keras.layers.Dense.
      means = tf.tensordot(inputs, mean, [[-1], [0]])
      stddevs = tf.sqrt(
          tf.tensordot(tf.square(inputs), tf.exp(log_variance), [[-1], [0]]) +
          tf.keras.backend.epsilon())
    if self.use_bias:
      means = tf.nn.bias_add(means, self.bias)
    outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
    if self.activation is not None:
      outputs = self.activation(outputs)
    return outputs

  # Following tf.keras.Dropout, only apply variational dropout if training
  # flag is True.
  training_value = utils.smart_constant_value(training)
  if training_value is not None:
    if training_value:
      return dropped_inputs()
    else:
      return super(DenseVariationalDropout, self).call(inputs)
  return tf.cond(
      pred=training,
      true_fn=dropped_inputs,
      false_fn=lambda: super(DenseVariationalDropout, self).call(inputs))
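
# Usage sketch (illustrative; not part of the original source). It assumes the
# class above is exposed as `ed.layers.DenseVariationalDropout` and takes the
# standard tf.keras.layers.Dense constructor arguments. Any KL/regularization
# terms the layer registers (e.g. toward the log-uniform prior used by
# variational dropout) would appear in `layer.losses` and be added to the
# training objective; treat that detail as an assumption about the defaults.
import edward2 as ed
import tensorflow as tf

layer = ed.layers.DenseVariationalDropout(units=64, activation='relu')
features = tf.random.normal([32, 128])
noisy_out = layer(features, training=True)    # samples noisy pre-activations
plain_out = layer(features, training=False)   # parent Dense-style call
kl_penalty = sum(layer.losses)  # regularizer terms registered by the layer (assumed)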