def call(self, inputs):
  if self.conditional_inputs is None and self.conditional_outputs is None:
    covariance_matrix = self.covariance_fn(inputs, inputs)
    # Tile locations so output has shape [units, batch_size]. Covariance will
    # broadcast to [units, batch_size, batch_size], and we perform
    # shape manipulations to get a random variable over [batch_size, units].
    loc = self.mean_fn(inputs)
    loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
  else:
    knn = self.covariance_fn(inputs, inputs)
    knm = self.covariance_fn(inputs, self.conditional_inputs)
    kmm = self.covariance_fn(self.conditional_inputs, self.conditional_inputs)
    kmm = tf.matrix_set_diag(
        kmm, tf.matrix_diag_part(kmm) + tf.keras.backend.epsilon())
    kmm_tril = tf.linalg.cholesky(kmm)
    kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
    knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

    # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
    # now, we do each separately and stack to obtain a locations Tensor of
    # shape [units, batch_size].
    loc = []
    for conditional_outputs_unit in tf.unstack(self.conditional_outputs,
                                               axis=-1):
      center = conditional_outputs_unit - self.mean_fn(
          self.conditional_inputs)
      loc_unit = knm_operator.matvec(
          kmm_tril_operator.solvevec(kmm_tril_operator.solvevec(center),
                                     adjoint=True))
      loc.append(loc_unit)
    loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

    covariance_matrix = knn
    covariance_matrix -= knm_operator.matmul(
        kmm_tril_operator.solve(
            kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))
    covariance_matrix = tf.matrix_set_diag(
        covariance_matrix,
        tf.matrix_diag_part(covariance_matrix) + tf.keras.backend.epsilon())

  # Form a multivariate normal random variable with batch_shape units and
  # event_shape batch_size. Then make it be independent across the units
  # dimension. Then transpose its dimensions so it is [batch_size, units].
  random_variable = ed.MultivariateNormalFullCovariance(
      loc=loc, covariance_matrix=covariance_matrix)
  random_variable = ed.Independent(random_variable.distribution,
                                   reinterpreted_batch_ndims=1)
  bijector = tfp.bijectors.Inline(
      forward_fn=lambda x: tf.transpose(x, [1, 0]),
      inverse_fn=lambda y: tf.transpose(y, [1, 0]),
      forward_event_shape_fn=lambda input_shape: input_shape[::-1],
      forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
      inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
      forward_min_event_ndims=2)
  random_variable = ed.TransformedDistribution(random_variable.distribution,
                                               bijector=bijector)
  return random_variable
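# The sketch below (added for illustration, not from the original source)
# shows how the unconditional branch of `call` might be exercised. It
# assumes the method belongs to an Edward2-style `GaussianProcess` Keras
# layer whose constructor takes `units`; that name and signature, and the
# feature shapes, are assumptions.
def _demo_gaussian_process_call():
  gp_layer = GaussianProcess(units=2)  # hypothetical constructor
  features = tf.random_normal([8, 3])  # [batch_size=8, num_features=3]
  # Returns an ed.RandomVariable with event shape [8, 2]: one joint GP
  # draw per output unit, transposed to [batch_size, units] by the
  # Inline bijector at the end of `call`.
  outputs = gp_layer(features)
  return outputs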
def make_mfvi_sgp_mixture_family(n_mixture, N, gp_dist, name,
                                 use_logistic_link=False):
  """Makes a mixture of MFVI and sparse GP variational priors.

  Args:
    n_mixture: (int) Number of MFVI mixture components.
    N: (int) Number of sample observations.
    gp_dist: (tfd.Distribution) Variational family for the Gaussian process.
    name: (str) Name prefix for the created parameters.
    use_logistic_link: (bool) Whether to transform the mixture with a
      Sigmoid bijector.

  Returns:
    mfvi_sgp_mix_dist: (tfd.Distribution) Mixture of the MFVI family and
      gp_dist.
    mixture_par_list: (list) Trainable mixture parameters:
      mixture_logits: (tf.Variable) Mixture probability (logit) between
        the MFVI family and gp_dist.
      mixture_logits_mfvi_mix: (tf.Variable or None) Mixture probability
        (logit) for the MFVI families. None if n_mixture=1.
      qf_mean_mfvi_mix, qf_sdev_mfvi_mix: (tf.Variable) Mean and sdev for
        the MFVI families. Shape (n_mixture, N) if n_mixture > 1, and
        shape (N,) if n_mixture = 1.
  """
  # Mixture probability (logit) between the MFVI family and the GP family.
  mixture_logits = tf.get_variable(name="{}_mixture_logits".format(name),
                                   shape=[2])

  (mfvi_mix_dist, mixture_logits_mfvi_mix,
   qf_mean_mfvi_mix, qf_sdev_mfvi_mix) = make_mfvi_mixture_family(
       n_mixture=n_mixture, N=N, name=name)

  mixture_par_list = [
      mixture_logits, mixture_logits_mfvi_mix,
      qf_mean_mfvi_mix, qf_sdev_mfvi_mix
  ]

  if use_logistic_link:
    mfvi_sgp_mix_dist = ed.TransformedDistribution(
        tfd.Mixture(cat=tfd.Categorical(logits=mixture_logits),
                    components=[mfvi_mix_dist, gp_dist]),
        bijector=tfp.bijectors.Sigmoid(),
        name=name)
  else:
    mfvi_sgp_mix_dist = ed.Mixture(
        cat=tfd.Categorical(logits=mixture_logits),
        components=[mfvi_mix_dist, gp_dist],
        name=name)

  return mfvi_sgp_mix_dist, mixture_par_list
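# The sketch below (added for illustration, not from the original source)
# shows one way to call `make_mfvi_sgp_mixture_family`. The standard-normal
# stand-in for the sparse GP variational family is an assumption; any
# tfd.Distribution whose event shape matches the MFVI family over N points
# should slot in the same way.
def _demo_mfvi_sgp_mixture():
  N = 50
  gp_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([N]))  # stand-in family
  mix_dist, mix_pars = make_mfvi_sgp_mixture_family(
      n_mixture=3, N=N, gp_dist=gp_dist, name="demo")
  # mix_pars collects the trainable mixture parameters to hand to an
  # optimizer alongside the variables inside gp_dist.
  return mix_dist, mix_pars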
def german_credit_model():
  x_numeric = tf.constant(numericals.astype(np.float32))
  # One-hot encode each categorical column, sized to its largest level.
  x_categorical = [tf.one_hot(c, c.max() + 1) for c in categoricals]
  all_x = tf.concat([x_numeric] + x_categorical, 1)
  num_features = int(all_x.shape[1])

  # Hierarchical prior on the regression scales: a shared overall
  # log-scale, plus per-feature Gamma(0.5, 0.5) scales parameterized on
  # the log axis via the inverse-Exp (i.e., log) bijector.
  overall_log_scale = ed.Normal(loc=0., scale=10., name='overall_log_scale')
  beta_log_scales = ed.TransformedDistribution(
      tfd.Gamma(0.5 * tf.ones([num_features]), 0.5),
      bijector=tfb.Invert(tfb.Exp()),
      name='beta_log_scales')
  beta = ed.Normal(loc=tf.zeros([num_features]),
                   scale=tf.exp(overall_log_scale + beta_log_scales),
                   name='beta')
  # The leading unit axis on beta lets logits broadcast over any chain or
  # sample dimension substituted for beta during inference.
  logits = tf.einsum('nd,md->mn', all_x, beta[tf.newaxis, :])
  return ed.Bernoulli(logits=logits, name='y')
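# The sketch below (added for illustration, not from the original source)
# turns the model into a joint log-density with ed.make_log_joint_fn, the
# standard Edward2 route to an MCMC or VI target. The keyword names match
# the `name=` strings in the model; `y_obs` stands in for the observed
# labels and is an assumption here.
def _demo_german_credit_target(y_obs):
  log_joint = ed.make_log_joint_fn(german_credit_model)

  def target_log_prob_fn(overall_log_scale, beta_log_scales, beta):
    # Condition on the observed labels; free arguments are the latents.
    return log_joint(overall_log_scale=overall_log_scale,
                     beta_log_scales=beta_log_scales,
                     beta=beta,
                     y=y_obs)

  return target_log_prob_fn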
def __call__(self, inputs, *args, **kwargs):
  if not self.built:
    # Data-dependent initialization: set bias and log_scale from the first
    # batch so that initial post-actnorm activations are normalized per
    # channel.
    mean, variance = tf.nn.moments(
        inputs, axes=[i for i in range(inputs.shape.ndims - 1)])
    self.bias_initial_value = -mean
    # TODO(trandustin): Optionally, actnorm multiplies log_scale by a fixed
    # log_scale factor (e.g., 3.) and initializes by
    # initial_value / log_scale_factor.
    self.log_scale_initial_value = tf.log(
        1. / (tf.sqrt(variance) + self.epsilon))

  if not isinstance(inputs, ed.RandomVariable):
    return super(ActNorm, self).__call__(inputs, *args, **kwargs)

  bijector = tfp.bijectors.Inline(
      forward_fn=self.__call__,
      inverse_fn=self.reverse,
      inverse_log_det_jacobian_fn=lambda y: -self.log_det_jacobian(y),
      forward_min_event_ndims=0)
  return ed.TransformedDistribution(inputs.distribution, bijector=bijector)
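# The sketch below (added for illustration, not from the original source)
# exercises the two code paths in `__call__`. It assumes `ActNorm` can be
# constructed with no required arguments; that, and the shapes, are
# assumptions for illustration only.
def _demo_actnorm():
  actnorm = ActNorm()
  # Plain Tensor input: triggers data-dependent initialization on first
  # use, then applies the forward transform as an ordinary Keras layer.
  plain = actnorm(tf.random_normal([4, 8]))
  # RandomVariable input: returns an ed.TransformedDistribution wrapping
  # the input's distribution with the layer's Inline bijector, so log
  # probabilities propagate through the flow.
  base = ed.Normal(loc=tf.zeros([4, 8]), scale=1.)
  transformed = actnorm(base)
  return plain, transformed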