def __call__(self, x):
    """Apply the layer to `x` by sampling pre-activations directly.

    Rather than sampling a weight matrix, this builds a Normal over the
    pre-activations from the kernel posterior's loc/scale (mean
    ``x @ kernel_loc``, stddev ``sqrt(x**2 @ kernel_scale**2)``), optionally
    folds in the bias, and draws one sample from it.

    Side effect: ``self._posterior_value`` is overwritten with the result of
    ``self.posterior_value_fn(self.posterior, seed=self._seed())``.

    Args:
      x: Input convertible to a tensor of ``self.dtype``. Presumably shaped
        ``(batch, in_features)`` so it can be right-multiplied by the kernel
        loc -- TODO confirm.

    Returns:
      The sampled pre-activations, passed through ``self.activation_fn`` when
      that attribute is not ``None``.
    """
    x = tf.convert_to_tensor(x, dtype=self.dtype, name='x')

    # Compute and cache the posterior value that the rest of the method is
    # conditioned on.
    self._posterior_value = self.posterior_value_fn(  # pylint: disable=not-callable
        self.posterior, seed=self._seed())
    posterior_dists = self.posterior.sample_distributions(
        value=self.posterior_value)[0]
    w_dist, b_dist = self.unpack_weights_fn(posterior_dists)  # pylint: disable=not-callable

    # Mean and stddev of x @ W, treating W's entries as independent Normals
    # with the returned loc/scale. Any exception raised for the kernel here
    # propagates to the caller.
    w_loc, w_scale = vi_lib.get_spherical_normal_loc_scale(w_dist)
    mean = tf.matmul(x, w_loc)
    # NOTE(review): d/du sqrt(u) is unbounded at u == 0, so an all-zero input
    # row may produce non-finite gradients here -- consider guarding.
    stddev = tf.sqrt(tf.matmul(tf.square(x), tf.square(w_scale)))

    _, b_sample = self.unpack_weights_fn(self.posterior_value)  # pylint: disable=not-callable
    if b_sample is not None:
        try:
            b_loc, b_scale = vi_lib.get_spherical_normal_loc_scale(b_dist)
        except TypeError:
            # A TypeError is taken to mean the bias posterior is not a
            # spherical Normal: shift by the sampled bias value instead.
            mean = mean + b_sample
        else:
            # Normal bias: add its loc to the mean and combine its scale
            # with the existing stddev in quadrature.
            mean = mean + b_loc
            stddev = tf.sqrt(tf.square(stddev) + tf.square(b_scale))

    outputs = normal_lib.Normal(loc=mean, scale=stddev).sample(
        seed=self._seed())
    activation = self.activation_fn
    return outputs if activation is None else activation(outputs)
# Example No. 2
# 0
 def _eval(self, x, weights):
   """Sample the layer's pre-activations for input `x` under `weights`.

   The posterior distributions conditioned on `weights` are unpacked into
   kernel and bias parts. The kernel's loc/scale define a Normal over the
   pre-activations (mean ``x @ kernel_loc``, stddev
   ``sqrt(x**2 @ kernel_scale**2)``), the bias is optionally folded in, and
   one sample is drawn. No activation function is applied here.

   Args:
     x: Input tensor. It is not converted here; presumably it already has
       the layer's dtype and a ``(batch, in_features)`` shape -- TODO confirm.
     weights: Posterior value, passed both to
       ``self.posterior.sample_distributions`` and to
       ``self.unpack_weights_fn``.

   Returns:
     One sample from the Normal over the pre-activations.
   """
   conditioned = self.posterior.sample_distributions(value=weights)[0]
   w_dist, b_dist = self.unpack_weights_fn(conditioned)  # pylint: disable=not-callable

   # Mean and stddev of x @ W, treating W's entries as independent Normals
   # with the returned loc/scale.
   w_loc, w_scale = vi_lib.get_spherical_normal_loc_scale(w_dist)
   mean = tf.matmul(x, w_loc)
   # NOTE(review): d/du sqrt(u) is unbounded at u == 0, so an all-zero input
   # row may produce non-finite gradients here -- consider guarding.
   stddev = tf.sqrt(tf.matmul(tf.square(x), tf.square(w_scale)))

   _, b_sample = self.unpack_weights_fn(weights)  # pylint: disable=not-callable
   if b_sample is not None:
     try:
       b_loc, b_scale = vi_lib.get_spherical_normal_loc_scale(b_dist)
     except TypeError:
       # A TypeError is taken to mean the bias posterior is not a spherical
       # Normal: shift by the sampled bias value instead.
       mean = mean + b_sample
     else:
       # Normal bias: add its loc to the mean and combine its scale with the
       # existing stddev in quadrature.
       mean = mean + b_loc
       stddev = tf.sqrt(tf.square(stddev) + tf.square(b_scale))

   return normal_lib.Normal(loc=mean, scale=stddev).sample(seed=self._seed())