Example #1: Conv2DVariationalDropout.call
  def call(self, inputs, training=None):
    if not isinstance(self.kernel, random_variable.RandomVariable):
      return super(Conv2DVariationalDropout, self).call(inputs)
    self.call_weights()
    if training is None:
      training = tf.keras.backend.learning_phase()
    # Lazily construct the convolution op from the layer's configuration.
    if self._convolution_op is None:
      padding = self.padding
      if self.padding == 'causal':
        padding = 'valid'
      if not isinstance(padding, (list, tuple)):
        padding = padding.upper()
      self._convolution_op = functools.partial(
          tf.nn.convolution,
          strides=self.strides,
          padding=padding,
          data_format='NHWC' if self.data_format == 'channels_last' else 'NCHW',
          dilations=self.dilation_rate)

    def dropped_inputs():
      """Forward pass with dropout."""
      # Clip the magnitude of the dropout rate, where the dropout rate alpha
      # comes from the additive parameterization (Molchanov et al., 2017): for
      # a weight ~ Normal(mu, sigma**2), the variance is `sigma**2 = alpha * mu**2`.
      mean = self.kernel.distribution.mean()
      log_variance = tf.math.log(self.kernel.distribution.variance())
      log_alpha = log_variance - tf.math.log(tf.square(mean) +
                                             tf.keras.backend.epsilon())
      log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
      log_variance = log_alpha + tf.math.log(tf.square(mean) +
                                             tf.keras.backend.epsilon())

      # Local reparameterization: the pre-activation is Normal with mean
      # conv(inputs, mu) and variance conv(inputs**2, sigma**2).
      means = self._convolution_op(inputs, mean)
      stddevs = tf.sqrt(
          self._convolution_op(tf.square(inputs), tf.exp(log_variance)) +
          tf.keras.backend.epsilon())
      if self.use_bias:
        if self.data_format == 'channels_first':
          means = tf.nn.bias_add(means, self.bias, data_format='NCHW')
        else:
          means = tf.nn.bias_add(means, self.bias, data_format='NHWC')
      outputs = generated_random_variables.Normal(loc=means, scale=stddevs)
      if self.activation is not None:
        outputs = self.activation(outputs)
      return outputs

    # Following tf.keras.layers.Dropout, only apply variational dropout if the
    # training flag is True.
    training_value = utils.smart_constant_value(training)
    if training_value is not None:
      if training_value:
        return dropped_inputs()
      else:
        return super(Conv2DVariationalDropout, self).call(inputs)
    return tf.cond(
        pred=training,
        true_fn=dropped_inputs,
        false_fn=lambda: super(Conv2DVariationalDropout, self).call(inputs))
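
The convolution above applies the local reparameterization trick: instead of sampling a kernel and convolving with it, it convolves the input with the kernel mean and the squared input with the kernel variance, then samples the Normal pre-activation. Below is a minimal stand-alone sketch of that computation with made-up shapes, calling tf.nn.convolution directly; it is not part of the layer.

import tensorflow as tf

x = tf.random.normal([8, 32, 32, 3])                          # NHWC inputs
kernel_mean = tf.random.normal([3, 3, 3, 16])                 # HWIO kernel mean
kernel_var = tf.nn.softplus(tf.random.normal([3, 3, 3, 16]))  # positive kernel variance

# Output mean and variance under the local reparameterization trick.
out_mean = tf.nn.convolution(x, kernel_mean, strides=[1, 1], padding='SAME')
out_var = tf.nn.convolution(tf.square(x), kernel_var, strides=[1, 1], padding='SAME')
out_std = tf.sqrt(out_var + tf.keras.backend.epsilon())

# One sample of the pre-activation; the layer instead wraps this as a Normal
# random variable so downstream code can inspect its distribution.
sample = out_mean + out_std * tf.random.normal(tf.shape(out_mean))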
Example #2: DenseVariationalDropout.call
    def call(self, inputs, training=None):
        if not isinstance(self.kernel, random_variable.RandomVariable):
            return super(DenseVariationalDropout, self).call(inputs)
        self.call_weights()
        if training is None:
            training = tf.keras.backend.learning_phase()

        def dropped_inputs():
            """Forward pass with dropout."""
            # Clip the magnitude of the dropout rate, where the dropout rate
            # alpha comes from the additive parameterization (Molchanov et al.,
            # 2017): for a weight ~ Normal(mu, sigma**2), the variance is
            # `sigma**2 = alpha * mu**2`.
            mean = self.kernel.distribution.mean()
            log_variance = tf.math.log(self.kernel.distribution.variance())
            log_alpha = log_variance - tf.math.log(
                tf.square(mean) + tf.keras.backend.epsilon())
            log_alpha = tf.clip_by_value(log_alpha, -8., 8.)
            log_variance = log_alpha + tf.math.log(
                tf.square(mean) + tf.keras.backend.epsilon())

            # Use matmul for rank-2 inputs; otherwise contract over the last axis.
            if inputs.shape.ndims <= 2:
                means = tf.matmul(inputs, mean)
                stddevs = tf.sqrt(
                    tf.matmul(tf.square(inputs), tf.exp(log_variance)) +
                    tf.keras.backend.epsilon())
            else:
                means = tf.tensordot(inputs, mean, [[-1], [0]])
                stddevs = tf.sqrt(
                    tf.tensordot(tf.square(inputs), tf.exp(log_variance),
                                 [[-1], [0]]) + tf.keras.backend.epsilon())
            if self.use_bias:
                means = tf.nn.bias_add(means, self.bias)
            outputs = generated_random_variables.Normal(loc=means,
                                                        scale=stddevs)
            if self.activation is not None:
                outputs = self.activation(outputs)
            return outputs

        # Following tf.keras.layers.Dropout, only apply variational dropout if
        # the training flag is True.
        training_value = utils.smart_constant_value(training)
        if training_value is not None:
            if training_value:
                return dropped_inputs()
            else:
                return super(DenseVariationalDropout, self).call(inputs)
        return tf.cond(
            pred=training,
            true_fn=dropped_inputs,
            false_fn=lambda: super(DenseVariationalDropout, self).call(inputs))
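
Both snippets clip the per-weight dropout rate before using the variance. Under the additive parameterization of Molchanov et al. (2017), a weight ~ Normal(mu, sigma**2) with sigma**2 = alpha * mu**2 has dropout rate alpha = sigma**2 / mu**2, and the clip bounds log(alpha) to [-8, 8]. Below is a small stand-alone sketch of that step with made-up numbers; it is not part of the layer.

import tensorflow as tf

mu = tf.constant([0.5, -1.0, 2.0])
sigma2 = tf.constant([1e-6, 0.5, 4e4])  # tiny, moderate, and huge dropout rates

log_alpha = tf.math.log(sigma2) - tf.math.log(tf.square(mu) + tf.keras.backend.epsilon())
log_alpha = tf.clip_by_value(log_alpha, -8., 8.)  # bound alpha to [e**-8, e**8]
log_sigma2 = log_alpha + tf.math.log(tf.square(mu) + tf.keras.backend.epsilon())

print(tf.exp(log_alpha).numpy())   # clipped dropout rates: only the extremes change
print(tf.exp(log_sigma2).numpy())  # variances made consistent with the clipped rates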