def __call__(self, x):
  """Computes the layer output for input `x` via local reparameterization.

  Draws a fresh posterior value for the weights, evaluates the layer's
  sampling-free forward pass (`_eval`), and applies the activation.

  Args:
    x: Input tensor; converted to a tensor of the layer's `self.dtype`.

  Returns:
    The (optionally activated) sampled layer output.
  """
  x = tf.convert_to_tensor(x, dtype=self.dtype, name='x')
  # Draw and cache this call's weight realization. NOTE(review): assumes
  # `self.posterior_value` is a property exposing `self._posterior_value`
  # (the original inline code read `self.posterior_value` the same way).
  self._posterior_value = self.posterior_value_fn(
      self.posterior, seed=self._seed())  # pylint: disable=not-callable
  # Delegate the moment computation + output sampling to `_eval`, which
  # performs the identical computation the original duplicated inline here.
  y = self._eval(x, self.posterior_value)
  if self.activation_fn is not None:
    y = self.activation_fn(y)
  return y
def _eval(self, x, weights):
  """Samples the layer output for `x` given a fixed weight realization.

  Uses local reparameterization: instead of multiplying `x` by sampled
  kernel weights, it computes the mean and standard deviation of the
  pre-activation induced by a spherical-normal kernel posterior and draws
  the output from that Normal directly.

  Args:
    x: Input tensor, matmul-compatible with the kernel.
    weights: A weight realization used to condition the posterior
      distributions (same structure `self.unpack_weights_fn` expects).

  Returns:
    A sample of the layer output (no activation applied).
  """
  kernel_dist, bias_dist = self.unpack_weights_fn(  # pylint: disable=not-callable
      self.posterior.sample_distributions(value=weights)[0])
  kernel_loc, kernel_scale = vi_lib.get_spherical_normal_loc_scale(
      kernel_dist)
  # Moments of the pre-activation under a spherical-normal kernel:
  # mean = x @ mu, var = x^2 @ sigma^2.
  loc = tf.matmul(x, kernel_loc)
  scale = tf.sqrt(tf.matmul(tf.square(x), tf.square(kernel_scale)))
  _, sampled_bias = self.unpack_weights_fn(weights)  # pylint: disable=not-callable
  if sampled_bias is not None:
    # Prefer folding the bias analytically into the Normal's moments;
    # fall back to adding the sampled bias when the bias posterior is not
    # spherical normal (signaled by TypeError).
    bias_loc_scale = None
    try:
      bias_loc_scale = vi_lib.get_spherical_normal_loc_scale(bias_dist)
    except TypeError:
      pass
    if bias_loc_scale is None:
      loc = loc + sampled_bias
    else:
      bias_loc, bias_scale = bias_loc_scale
      loc = loc + bias_loc
      scale = tf.sqrt(tf.square(scale) + tf.square(bias_scale))
  return normal_lib.Normal(loc=loc, scale=scale).sample(seed=self._seed())