Example #1
  def call(self, inputs):
    if self.conditional_inputs is None and self.conditional_outputs is None:
      covariance_matrix = self.covariance_fn(inputs, inputs)
      # Tile locations so output has shape [units, batch_size]. Covariance will
      # broadcast to [units, batch_size, batch_size], and we perform
      # shape manipulations to get a random variable over [batch_size, units].
      loc = self.mean_fn(inputs)
      loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
    else:
      knn = self.covariance_fn(inputs, inputs)
      knm = self.covariance_fn(inputs, self.conditional_inputs)
      kmm = self.covariance_fn(self.conditional_inputs, self.conditional_inputs)
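      # Add jitter to the diagonal of kmm so the Cholesky factorization
      # below stays numerically stable.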
      kmm = tf.linalg.set_diag(
          kmm, tf.linalg.diag_part(kmm) + tf.keras.backend.epsilon())
      kmm_tril = tf.linalg.cholesky(kmm)
      kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
      knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

      # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
      # now, we do each separately and stack to obtain a locations Tensor of
      # shape [units, batch_size].
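      # Each unit's posterior mean is m(X_new) + K_nm K_mm^{-1} (y - m(X_m)),
      # computed below via two triangular solves against the Cholesky factor
      # of kmm rather than an explicit matrix inverse.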
      loc = []
      for conditional_outputs_unit in tf.unstack(self.conditional_outputs,
                                                 axis=-1):
        center = conditional_outputs_unit - self.mean_fn(
            self.conditional_inputs)
        loc_unit = knm_operator.matvec(
            kmm_tril_operator.solvevec(kmm_tril_operator.solvevec(center),
                                       adjoint=True))
        loc.append(loc_unit)
      loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

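      # Posterior covariance is K_nn - K_nm K_mm^{-1} K_mn, again computed
      # with triangular solves against the Cholesky factor of kmm.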
      covariance_matrix = knn
      covariance_matrix -= knm_operator.matmul(
          kmm_tril_operator.solve(
              kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))

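    # Add jitter to the diagonal so the covariance stays positive definite
    # for the multivariate normal constructed below.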
    covariance_matrix = tf.linalg.set_diag(
        covariance_matrix,
        tf.linalg.diag_part(covariance_matrix) + tf.keras.backend.epsilon())

    # Form a multivariate normal random variable with batch_shape units and
    # event_shape batch_size. Then make it be independent across the units
    # dimension. Then transpose its dimensions so it is [batch_size, units].
    random_variable = (
        generated_random_variables.MultivariateNormalFullCovariance(
            loc=loc, covariance_matrix=covariance_matrix))
    random_variable = generated_random_variables.Independent(
        random_variable.distribution, reinterpreted_batch_ndims=1)
    bijector = tfp.bijectors.Inline(
        forward_fn=lambda x: tf.transpose(x, perm=[1, 0]),
        inverse_fn=lambda y: tf.transpose(y, perm=[1, 0]),
        forward_event_shape_fn=lambda input_shape: input_shape[::-1],
        forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
        inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
        forward_min_event_ndims=2)
    random_variable = generated_random_variables.TransformedDistribution(
        random_variable.distribution, bijector=bijector)
    return random_variable
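
The snippet above reads like the `call` method of a Keras-style Gaussian
process layer; the attribute names (units, mean_fn, covariance_fn,
conditional_inputs, conditional_outputs) match edward2's `GaussianProcess`.
A minimal usage sketch under that assumption, with `edward2` importable as
`ed` (hypothetical; not part of the snippet):

    import edward2 as ed
    import tensorflow as tf

    # Assumed layer: edward2's GaussianProcess. With no conditional data it
    # takes the first branch above and returns the GP prior at the inputs.
    gp = ed.layers.GaussianProcess(units=1)
    x = tf.random.normal([8, 3])       # batch of 8 inputs, 3 features each
    outputs = gp(x)                    # random variable of shape [8, 1]
    sample = tf.convert_to_tensor(outputs)  # draw a sample from the GP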
Example #2
    def call(self, inputs):
        if (not isinstance(inputs, random_variable.RandomVariable)
                and not isinstance(self.kernel, random_variable.RandomVariable)
                and not isinstance(self.bias, random_variable.RandomVariable)):
            return super(DenseDVI, self).call(inputs)
        self.call_weights()
        inputs_mean, inputs_variance, inputs_covariance = get_moments(inputs)
        kernel_mean, kernel_variance, _ = get_moments(self.kernel)
        if self.use_bias:
            bias_mean, _, bias_covariance = get_moments(self.bias)

        # E[outputs] = E[inputs] * E[kernel] + E[bias]
        mean = tf.tensordot(inputs_mean, kernel_mean, [[-1], [0]])
        if self.use_bias:
            mean = tf.nn.bias_add(mean, bias_mean)

        # Cov = E[inputs**2] Cov(kernel) + E[W]^T Cov(inputs) E[W] + Cov(bias)
        # For first term, assume Cov(kernel) = 0 on off-diagonals so we only
        # compute diagonal term.
        covariance_diag = tf.tensordot(inputs_variance + inputs_mean**2,
                                       kernel_variance, [[-1], [0]])
        # Compute quadratic form E[W]^T Cov E[W] from right-to-left. First is
        #  [..., features, features], [features, units] -> [..., features, units].
        cov_w = tf.tensordot(inputs_covariance, kernel_mean, [[-1], [0]])
        # Next is [..., features, units], [features, units] -> [..., units, units].
        w_cov_w = tf.tensordot(cov_w, kernel_mean, [[-2], [0]])
        covariance = w_cov_w
        if self.use_bias:
            covariance += bias_covariance
        covariance = tf.linalg.set_diag(
            covariance,
            tf.linalg.diag_part(covariance) + covariance_diag)

        if self.activation in (tf.keras.activations.relu, tf.nn.relu):
            # Compute activation's moments with variable names from Wu et al. (2018).
            variance = tf.linalg.diag_part(covariance)
            scale = tf.sqrt(variance)
            mu = mean / (scale + tf.keras.backend.epsilon())
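            # E[relu(x)] = sigma * SR(mu / sigma) for x ~ N(mu, sigma^2),
            # where SR is the "soft relu" of Wu et al. (2018).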
            mean = scale * soft_relu(mu)

            # [..., units, units]
            pairwise_variances = (tf.expand_dims(variance, -1) *
                                  tf.expand_dims(variance, -2))
            rho = covariance / tf.sqrt(pairwise_variances +
                                       tf.keras.backend.epsilon())
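            # Clip the correlation strictly inside (-1, 1) so that
            # bar_rho = sqrt(1 - rho**2) below stays positive.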
            rho = tf.clip_by_value(rho,
                                   -1. / (1. + tf.keras.backend.epsilon()),
                                   1. / (1. + tf.keras.backend.epsilon()))
            s = covariance / (rho + tf.keras.backend.epsilon())
            mu1 = tf.expand_dims(mu, -1)  # [..., units, 1]
            mu2 = tf.linalg.matrix_transpose(mu1)  # [..., 1, units]
            a = (soft_relu(mu1) * soft_relu(mu2) +
                 rho * tfp.distributions.Normal(0., 1.).cdf(mu1) *
                 tfp.distributions.Normal(0., 1.).cdf(mu2))
            gh = tf.asinh(rho)
            bar_rho = tf.sqrt(1. - rho**2)
            gr = gh + rho / (1. + bar_rho)
            # Use numerically safe versions of gr and rho when multiplying or
            # dividing by them; gr always has the same sign as rho, so gr*rho
            # and rho/gr are never negative.
            safe_gr = tf.abs(gr) + 0.5 * tf.keras.backend.epsilon()
            safe_rho = tf.abs(rho) + tf.keras.backend.epsilon()
            exp_negative_q = gr / (2. * math.pi) * tf.exp(
                -safe_rho / (2. * safe_gr * (1. + bar_rho)) +
                (gh - rho) / (safe_gr * safe_rho) * mu1 * mu2)
            covariance = s * (a + exp_negative_q)
        elif self.activation not in (tf.keras.activations.linear, None):
            raise NotImplementedError(
                'Activation is {}. Deterministic variational '
                'inference is only available if activation is '
                'ReLU or None.'.format(self.activation))

        return generated_random_variables.MultivariateNormalFullCovariance(
            mean, covariance)
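
The snippet relies on two helpers defined elsewhere in the module,
`get_moments` and `soft_relu`. A minimal sketch of `soft_relu`, consistent
with how the code uses it and with Wu et al. (2018), where
SR(x) = pdf(x) + x * cdf(x) under a standard normal (a sketch, not
necessarily the module's exact definition):

    import tensorflow_probability as tfp

    def soft_relu(x):
        """SR(x) = phi(x) + x * Phi(x), i.e. E[relu(z)] for z ~ N(x, 1)."""
        std_normal = tfp.distributions.Normal(0., 1.)
        return std_normal.prob(x) + x * std_normal.cdf(x)

A hypothetical usage sketch, assuming this is edward2's `DenseDVI` layer:
calling it propagates moments deterministically and returns a multivariate
normal random variable over the output units instead of a point estimate:

    import edward2 as ed
    import tensorflow as tf

    layer = ed.layers.DenseDVI(units=16, activation=tf.nn.relu)
    outputs = layer(tf.random.normal([32, 8]))  # MVN over 16 units per row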