def call(self, inputs):
  if self.conditional_inputs is None and self.conditional_outputs is None:
    covariance_matrix = self.covariance_fn(inputs, inputs)
    # Tile locations so output has shape [units, batch_size]. Covariance will
    # broadcast to [units, batch_size, batch_size], and we perform
    # shape manipulations to get a random variable over [batch_size, units].
    loc = self.mean_fn(inputs)
    loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
  else:
    knn = self.covariance_fn(inputs, inputs)
    knm = self.covariance_fn(inputs, self.conditional_inputs)
    kmm = self.covariance_fn(self.conditional_inputs, self.conditional_inputs)
    kmm = tf.linalg.set_diag(
        kmm, tf.linalg.diag_part(kmm) + tf.keras.backend.epsilon())
    kmm_tril = tf.linalg.cholesky(kmm)
    kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
    knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

    # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
    # now, we do each separately and stack to obtain a locations Tensor of
    # shape [units, batch_size].
    loc = []
    for conditional_outputs_unit in tf.unstack(self.conditional_outputs,
                                               axis=-1):
      center = conditional_outputs_unit - self.mean_fn(
          self.conditional_inputs)
      loc_unit = knm_operator.matvec(
          kmm_tril_operator.solvevec(kmm_tril_operator.solvevec(center),
                                     adjoint=True))
      loc.append(loc_unit)
    loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

    covariance_matrix = knn
    covariance_matrix -= knm_operator.matmul(
        kmm_tril_operator.solve(
            kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))

  covariance_matrix = tf.linalg.set_diag(
      covariance_matrix,
      tf.linalg.diag_part(covariance_matrix) + tf.keras.backend.epsilon())

  # Form a multivariate normal random variable with batch_shape units and
  # event_shape batch_size. Then make it be independent across the units
  # dimension. Then transpose its dimensions so it is [batch_size, units].
  random_variable = (
      generated_random_variables.MultivariateNormalFullCovariance(
          loc=loc, covariance_matrix=covariance_matrix))
  random_variable = generated_random_variables.Independent(
      random_variable.distribution, reinterpreted_batch_ndims=1)
  bijector = tfp.bijectors.Inline(
      forward_fn=lambda x: tf.transpose(x, perm=[1, 0]),
      inverse_fn=lambda y: tf.transpose(y, perm=[1, 0]),
      forward_event_shape_fn=lambda input_shape: input_shape[::-1],
      forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
      inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
      forward_min_event_ndims=2)
  random_variable = generated_random_variables.TransformedDistribution(
      random_variable.distribution, bijector=bijector)
  return random_variable
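

# A minimal standalone sketch (not part of the layer) of the posterior
# computation in the conditional branch above: mean = knm kmm^{-1} (y - m(Xm))
# and covariance = knn - knm kmm^{-1} kmn, computed via Cholesky solves.
# Assumptions for illustration only: a plain RBF kernel, a zero mean function,
# and a single output unit; the helper name `_gp_posterior_sketch` and its
# arguments are hypothetical and do not exist in this module.
def _gp_posterior_sketch(x_new, x_obs, y_obs, length_scale=1., jitter=1e-6):
  """Returns a GP posterior mean and covariance at `x_new` given (x_obs, y_obs)."""
  import tensorflow as tf  # assumed already imported at module level

  def rbf(a, b):
    # Squared-exponential kernel on [batch, features] inputs.
    sq_dists = (tf.reduce_sum(a**2, -1, keepdims=True) -
                2. * tf.matmul(a, b, transpose_b=True) +
                tf.transpose(tf.reduce_sum(b**2, -1, keepdims=True)))
    return tf.exp(-0.5 * sq_dists / length_scale**2)

  knn = rbf(x_new, x_new)
  knm = rbf(x_new, x_obs)
  kmm = rbf(x_obs, x_obs) + jitter * tf.eye(tf.shape(x_obs)[0])
  kmm_tril = tf.linalg.cholesky(kmm)
  # Each cholesky_solve performs the pair of triangular solves done above with
  # LinearOperatorLowerTriangular.solve / solvevec.
  alpha = tf.linalg.cholesky_solve(kmm_tril, y_obs[:, tf.newaxis])
  mean = tf.squeeze(tf.matmul(knm, alpha), -1)           # knm kmm^{-1} y
  v = tf.linalg.cholesky_solve(kmm_tril, tf.transpose(knm))
  covariance = knn - tf.matmul(knm, v)                   # knn - knm kmm^{-1} kmn
  return mean, covariance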
def call(self, inputs):
  if (not isinstance(inputs, random_variable.RandomVariable) and
      not isinstance(self.kernel, random_variable.RandomVariable) and
      not isinstance(self.bias, random_variable.RandomVariable)):
    return super(DenseDVI, self).call(inputs)
  self.call_weights()
  inputs_mean, inputs_variance, inputs_covariance = get_moments(inputs)
  kernel_mean, kernel_variance, _ = get_moments(self.kernel)
  if self.use_bias:
    bias_mean, _, bias_covariance = get_moments(self.bias)

  # E[outputs] = E[inputs] * E[kernel] + E[bias]
  mean = tf.tensordot(inputs_mean, kernel_mean, [[-1], [0]])
  if self.use_bias:
    mean = tf.nn.bias_add(mean, bias_mean)

  # Cov = E[inputs**2] Cov(kernel) + E[W]^T Cov(inputs) E[W] + Cov(bias)
  # For first term, assume Cov(kernel) = 0 on off-diagonals so we only
  # compute diagonal term.
  covariance_diag = tf.tensordot(inputs_variance + inputs_mean**2,
                                 kernel_variance, [[-1], [0]])
  # Compute quadratic form E[W]^T Cov E[W] from right-to-left. First is
  # [..., features, features], [features, units] -> [..., features, units].
  cov_w = tf.tensordot(inputs_covariance, kernel_mean, [[-1], [0]])
  # Next is [..., features, units], [features, units] -> [..., units, units].
  w_cov_w = tf.tensordot(cov_w, kernel_mean, [[-2], [0]])
  covariance = w_cov_w
  if self.use_bias:
    covariance += bias_covariance
  covariance = tf.linalg.set_diag(
      covariance, tf.linalg.diag_part(covariance) + covariance_diag)

  if self.activation in (tf.keras.activations.relu, tf.nn.relu):
    # Compute activation's moments with variable names from Wu et al. (2018).
    variance = tf.linalg.diag_part(covariance)
    scale = tf.sqrt(variance)
    mu = mean / (scale + tf.keras.backend.epsilon())
    mean = scale * soft_relu(mu)

    pairwise_variances = (tf.expand_dims(variance, -1) *
                          tf.expand_dims(variance, -2))  # [..., units, units]
    rho = covariance / tf.sqrt(pairwise_variances +
                               tf.keras.backend.epsilon())
    rho = tf.clip_by_value(rho,
                           -1. / (1. + tf.keras.backend.epsilon()),
                           1. / (1. + tf.keras.backend.epsilon()))
    s = covariance / (rho + tf.keras.backend.epsilon())
    mu1 = tf.expand_dims(mu, -1)  # [..., units, 1]
    mu2 = tf.linalg.matrix_transpose(mu1)  # [..., 1, units]
    a = (soft_relu(mu1) * soft_relu(mu2) +
         rho * tfp.distributions.Normal(0., 1.).cdf(mu1) *
         tfp.distributions.Normal(0., 1.).cdf(mu2))
    gh = tf.asinh(rho)
    bar_rho = tf.sqrt(1. - rho**2)
    gr = gh + rho / (1. + bar_rho)
    # Include numerically stable versions of gr and rho when multiplying or
    # dividing them. The sign of gr*rho and rho/gr is always positive.
    safe_gr = tf.abs(gr) + 0.5 * tf.keras.backend.epsilon()
    safe_rho = tf.abs(rho) + tf.keras.backend.epsilon()
    exp_negative_q = gr / (2. * math.pi) * tf.exp(
        -safe_rho / (2. * safe_gr * (1 + bar_rho)) +
        (gh - rho) / (safe_gr * safe_rho) * mu1 * mu2)
    covariance = s * (a + exp_negative_q)
  elif self.activation not in (tf.keras.activations.linear, None):
    raise NotImplementedError('Activation is {}. Deterministic variational '
                              'inference is only available if activation is '
                              'ReLU or None.'.format(self.activation))

  return generated_random_variables.MultivariateNormalFullCovariance(
      mean, covariance)
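

# A minimal sanity-check sketch (illustrative only, not part of the layer) of
# the closed-form ReLU moment used above: for x ~ Normal(m, s^2),
# E[relu(x)] = s * soft_relu(m / s) with soft_relu(z) = pdf(z) + z * cdf(z),
# following Wu et al. (2018). The helper name `_check_relu_mean` is
# hypothetical and introduced here only to compare the formula against a
# Monte Carlo estimate.
def _check_relu_mean(m=0.3, s=1.5, num_samples=100000, seed=0):
  import tensorflow as tf  # assumed already imported at module level
  import tensorflow_probability as tfp
  standard_normal = tfp.distributions.Normal(0., 1.)

  def soft_relu(z):
    # Same expression as the module-level soft_relu used in DenseDVI.call.
    return standard_normal.prob(z) + z * standard_normal.cdf(z)

  closed_form = s * soft_relu(m / s)
  samples = tfp.distributions.Normal(m, s).sample(num_samples, seed=seed)
  monte_carlo = tf.reduce_mean(tf.nn.relu(samples))
  return closed_form, monte_carlo  # the two values should roughly agree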