Example #1
  def create_custom_getter(fresh_noise_per_connection):
    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=bbb.diagonal_gaussian_posterior_builder,
        prior_builder=bbb.fixed_gaussian_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=tf.constant(bbb.EstimatorModes.sample),
        fresh_noise_per_connection=fresh_noise_per_connection)
    return bbb_getter
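The getter built above is meant to be installed as a variable-scope custom getter, so that every tf.get_variable call under that scope is routed through Bayes by Backprop. A minimal usage sketch, assuming create_custom_getter is reachable as defined above and that tensorflow (tf) and the Sonnet bbb module are imported as in the snippet; the scope and variable names are illustrative:

# Variables created under this scope are backed by learned posteriors; in
# sample mode each read returns a sample, and the KL terms are accumulated
# for bbb.get_total_kl_cost().
getter = create_custom_getter(fresh_noise_per_connection=True)
with tf.variable_scope("bbb_model", custom_getter=getter):
  w = tf.get_variable("w", shape=[3, 3], dtype=tf.float32)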
Example #2
  def testWeightsResampledWithKeepControlDeps(self):
    """Test that weights are resampled with `keep_control_dependencies=True`.

    Test strategy: We test the inverse of `testRecurrentNetSamplesWeightsOnce`.
    Provide an input sequence x whose value is the same at each time step. If
    the outputs from f_theta() are different at each time step, then theta
    is different at each time step. In principle, it is possible that different
    thetas give the same outputs, but this is very unlikely.
    """
    seq_length = 10
    batch_size = 1
    input_dim = 5
    output_dim = 5

    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=bbb.diagonal_gaussian_posterior_builder,
        prior_builder=bbb.fixed_gaussian_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=tf.constant(bbb.EstimatorModes.sample),
        keep_control_dependencies=True)

    class NoStateLSTM(snt.LSTM):
      """An LSTM which ignores hidden state."""

      def _build(self, inputs, state):
        outputs, _ = super(NoStateLSTM, self)._build(inputs, state)
        return outputs, state

    with tf.variable_scope("model", custom_getter=bbb_getter):
      core = NoStateLSTM(output_dim)

    input_seq = tf.ones(shape=(seq_length, batch_size, input_dim))
    output_seq, _ = tf.nn.dynamic_rnn(
        core,
        inputs=input_seq,
        initial_state=core.initial_state(batch_size=batch_size),
        time_major=True)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      output_res_one = sess.run(output_seq)
      output_res_two = sess.run(output_seq)

    # Ensure that the output differs at every time step.
    output_zero = output_res_one[0]
    for time_step_output in output_res_one[1:]:
      distance = np.linalg.norm(
          time_step_output.flatten() - output_zero.flatten())
      self.assertGreater(distance, 0.001)

    # Ensure that the noise is different in the second run by checking that
    # the output sequence is different now.
    for first_run_elem, second_run_elem in zip(output_res_one, output_res_two):
      distance = np.linalg.norm(
          first_run_elem.flatten() - second_run_elem.flatten())
      self.assertGreater(distance, 0.001)
Example #3
def build_modules(is_training, vocab_size):
    """Construct the modules used in the graph."""

    # Construct the custom getter which implements Bayes by Backprop.
    if is_training:
        estimator_mode = tf.constant(bbb.EstimatorModes.sample)
    else:
        estimator_mode = tf.constant(bbb.EstimatorModes.mean)
    lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=lstm_posterior_builder,
        prior_builder=custom_scale_mixture_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=estimator_mode,
    )
    non_lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=non_lstm_posterior_builder,
        prior_builder=custom_scale_mixture_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=estimator_mode,
    )

    embed_layer = snt.Embed(
        vocab_size=vocab_size,
        embed_dim=FLAGS.embedding_size,
        custom_getter=non_lstm_bbb_custom_getter,
        name="input_embedding",
    )

    cores = [
        snt.LSTM(
            FLAGS.hidden_size,
            custom_getter=lstm_bbb_custom_getter,
            forget_bias=0.0,
            name="lstm_layer_{}".format(i),
        ) for i in six.moves.range(FLAGS.n_layers)
    ]
    rnn_core = snt.DeepRNN(cores,
                           skip_connections=False,
                           name="deep_lstm_core")

    # Do BBB on weights but not biases of output layer.
    output_linear = snt.Linear(vocab_size,
                               custom_getter={"w": non_lstm_bbb_custom_getter})
    return embed_layer, rnn_core, output_linear
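In the training graph these modules are typically unrolled over a token sequence and trained on the data log-likelihood plus the KL cost accumulated by the BBB getters. A rough sketch of that wiring under the usual Sonnet/TF idioms; token_ids, targets, batch_size and the KL scaling below are illustrative placeholders, not part of the snippet:

token_ids = tf.placeholder(tf.int32, [None, None])   # [T, B], illustrative
targets = tf.placeholder(tf.int32, [None, None])     # [T, B], illustrative
batch_size = 20                                       # illustrative

embed_layer, rnn_core, output_linear = build_modules(is_training=True,
                                                     vocab_size=10000)
embedded = embed_layer(token_ids)                     # [T, B, embed_dim]
rnn_output, _ = tf.nn.dynamic_rnn(
    rnn_core,
    inputs=embedded,
    initial_state=rnn_core.initial_state(batch_size=batch_size),
    time_major=True)
logits = snt.BatchApply(output_linear)(rnn_output)    # [T, B, vocab_size]
nll = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                    logits=logits))
# bbb.get_total_kl_cost() sums the KL terms registered by the custom getters;
# its scaling relative to the likelihood is problem-dependent.
total_loss = nll + 1e-3 * bbb.get_total_kl_cost()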
Example #4
  def testRecurrentNetSamplesWeightsOnce(self):
    """Test that sampling of the weights is done only once for a sequence.

    Test strategy: Provide an input sequence x whose value is the same at each
    time step. If the outputs from f_theta() are the same at each time step,
    this is evidence (but not proof) that theta is the same at each time step.
    """
    seq_length = 10
    batch_size = 1
    input_dim = 5
    output_dim = 5

    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=bbb.diagonal_gaussian_posterior_builder,
        prior_builder=bbb.fixed_gaussian_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=tf.constant(bbb.EstimatorModes.sample))

    class NoStateLSTM(snt.LSTM):
      """An LSTM which ignores hidden state."""

      def _build(self, inputs, state):
        outputs, _ = super(NoStateLSTM, self)._build(inputs, state)
        return outputs, state

    with tf.variable_scope("model", custom_getter=bbb_getter):
      core = NoStateLSTM(output_dim)

    input_seq = tf.ones(shape=(seq_length, batch_size, input_dim))
    output_seq, _ = tf.nn.dynamic_rnn(
        core,
        inputs=input_seq,
        initial_state=core.initial_state(batch_size=batch_size),
        time_major=True)

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      output_res_one = sess.run(output_seq)
      output_res_two = sess.run(output_seq)

    # Ensure that the sequence is the same at every time step, a necessary
    # but not sufficient condition for the weights to be the same.
    output_zero = output_res_one[0]
    for time_step_output in output_res_one[1:]:
      self.assertAllClose(output_zero, time_step_output)

    # Ensure that the noise is different in the second run by checking that
    # the output sequence is different now.
    for first_run_elem, second_run_elem in zip(output_res_one, output_res_two):
      distance = np.linalg.norm(
          first_run_elem.flatten() - second_run_elem.flatten())
      self.assertGreater(distance, 0.001)
Example #5
  def test_prune_by_bbb_from_scratch_is_correct(self):
    hidden_layers = [256, 128]

    sampling_mode_ph = array_ops.placeholder(dtypes.string, [])
    get_bbb_variable_fn = bayes_by_backprop.bayes_by_backprop_getter(
        posterior_builder=bayes_by_backprop.diagonal_gaussian_posterior_builder,
        prior_builder=bayes_by_backprop.adaptive_gaussian_prior_builder,
        kl_builder=bayes_by_backprop.stochastic_kl_builder,
        sampling_mode_tensor=sampling_mode_ph)

    # create the bayes network
    bbb_scope = 'net'
    inputs_ph = array_ops.placeholder(dtypes.float32, [None, 784])
    with variable_scope.variable_scope(bbb_scope, custom_getter=get_bbb_variable_fn) as vs:
      logits_bbb = mlp(inputs_ph, hidden_layers)

    # create the pruning op
    metadata = bayes_by_backprop.get_variable_metadata()
    total_variables = sum([np.prod(meta.raw_variable_shape) for meta in metadata])
    percentage_ph = array_ops.placeholder(dtypes.float32, [])
    pruned_vars_op = bbb.prune_by_bbb(metadata, percentage_ph)

    # create the template network
    template_scope = 'template'
    with variable_scope.variable_scope(template_scope) as vs:
      logits = mlp(inputs_ph, hidden_layers)
      # retrieve the variables so we can prune them in-place
      template_variables = variables.trainable_variables(scope=template_scope)

    # find the variables from 'net' that correspond to 'template'
    assign_pruned_vars_op = bbb.assign_pruned_by_bbb_to_template(
        metadata, pruned_vars_op, template_variables,
        from_scope=bbb_scope, to_scope=template_scope)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      for percentage in np.arange(0., 1., .05):
        # Ordinary pruning outside of this test would look like this:
        # >>> sess.run(assign_pruned_vars_op, feed_dict={percentage_ph: percentage})

        test_variables, _ = sess.run(
            (pruned_vars_op, assign_pruned_vars_op),
            feed_dict={
              percentage_ph: percentage,
              sampling_mode_ph: bayes_by_backprop.EstimatorModes.mean,
            })

        nonzero = 0
        for test_variable, template_variable in zip(test_variables, template_variables):
          nonzero += np.count_nonzero(test_variable)
          self.assertAllClose(test_variable, sess.run(template_variable), atol=1e-8)
        self.assertAlmostEqual(int(nonzero), int(total_variables * (1 - percentage)), delta=1)
Example #6
def build_modules(is_training, vocab_size):
  """Construct the modules used in the graph."""

  # Construct the custom getter which implements Bayes by Backprop.
  if is_training:
    estimator_mode = tf.constant(bbb.EstimatorModes.sample)
  else:
    estimator_mode = tf.constant(bbb.EstimatorModes.mean)
  lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
      posterior_builder=lstm_posterior_builder,
      prior_builder=custom_scale_mixture_prior_builder,
      kl_builder=bbb.stochastic_kl_builder,
      sampling_mode_tensor=estimator_mode)
  non_lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
      posterior_builder=non_lstm_posterior_builder,
      prior_builder=custom_scale_mixture_prior_builder,
      kl_builder=bbb.stochastic_kl_builder,
      sampling_mode_tensor=estimator_mode)

  embed_layer = snt.Embed(
      vocab_size=vocab_size,
      embed_dim=FLAGS.embedding_size,
      custom_getter=non_lstm_bbb_custom_getter,
      name="input_embedding")

  cores = [snt.LSTM(FLAGS.hidden_size,
                    custom_getter=lstm_bbb_custom_getter,
                    forget_bias=0.0,
                    name="lstm_layer_{}".format(i))
           for i in xrange(FLAGS.n_layers)]
  rnn_core = snt.DeepRNN(
      cores,
      skip_connections=False,
      name="deep_lstm_core")

  # Do BBB on weights but not biases of output layer.
  output_linear = snt.Linear(
      vocab_size, custom_getter={"w": non_lstm_bbb_custom_getter})
  return embed_layer, rnn_core, output_linear
Example #7
  def test_prune_by_bbb_from_scratch_to_sparse_ops(self):
    hidden_layers = [256, 128]

    sampling_mode_ph = array_ops.placeholder(dtypes.string, [])
    get_bbb_variable_fn = bayes_by_backprop.bayes_by_backprop_getter(
        posterior_builder=bayes_by_backprop.diagonal_gaussian_posterior_builder,
        prior_builder=bayes_by_backprop.adaptive_gaussian_prior_builder,
        kl_builder=bayes_by_backprop.stochastic_kl_builder,
        sampling_mode_tensor=sampling_mode_ph)

    # create the bayes network
    bbb_scope = 'net'
    inputs_ph = array_ops.placeholder(dtypes.float32, [None, 784])
    with variable_scope.variable_scope(bbb_scope, custom_getter=get_bbb_variable_fn) as vs:
      logits_bbb = mlp(inputs_ph, hidden_layers)

    # create the optimal pruning op
    metadata = bayes_by_backprop.get_variable_metadata()
    total_variables = sum([np.prod(meta.raw_variable_shape) for meta in metadata])
    percentage_ph = array_ops.placeholder(dtypes.float32, [])
    pruned_vars_op = bbb.prune_by_bbb(metadata, percentage_ph)

    # create the template network
    template_scope = 'template'
    with variable_scope.variable_scope(template_scope) as vs:
      sparse_logits = mlp(inputs_ph, hidden_layers)

      # retrieve the variables so we can prune them in-place
      template_variables = variables.trainable_variables(scope=template_scope)

    # find the variables from 'net' that correspond to 'template'
    assign_pruned_vars_op = bbb.assign_pruned_by_bbb_to_template(
        metadata, pruned_vars_op, template_variables,
        from_scope=bbb_scope, to_scope=template_scope)

    with self.test_session() as sess:
      test_inputs_ones = np.ones([1, 784])
      sess.run(variables.global_variables_initializer())
      for percentage in np.arange(0., 1.01, .01):
        sess.run(
            assign_pruned_vars_op, feed_dict={
              percentage_ph: percentage,
              sampling_mode_ph: bayes_by_backprop.EstimatorModes.mean,
            })
        sess.run(sparse_logits, feed_dict={inputs_ph: test_inputs_ones})
Example #8
  def test_mean_mode_is_deterministic_and_correct(self):
    softplus_of_three = softplus(3.0)

    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=test_diag_gaussian_builder_builder(10.9, 3.0),
        prior_builder=bbb.fixed_gaussian_prior_builder,
        sampling_mode_tensor=tf.constant(bbb.EstimatorModes.mean))

    with tf.variable_scope("my_scope", custom_getter=bbb_getter):
      my_variable = tf.get_variable("v", shape=[2], dtype=tf.float32)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with self.test_session() as sess:
      sess.run(init_op)
      variable_value_one = sess.run(my_variable)
      variable_value_two = sess.run(my_variable)
      variable_value_three = sess.run(my_variable)
    self.assertAllClose(variable_value_one,
                        np.zeros(shape=[2]) + 10.9,
                        atol=1e-5)
    self.assertAllClose(variable_value_two,
                        np.zeros(shape=[2]) + 10.9,
                        atol=1e-5)
    self.assertAllClose(variable_value_three,
                        np.zeros(shape=[2]) + 10.9,
                        atol=1e-5)

    variable_metadata = bbb.get_variable_metadata()
    self.assertTrue(len(variable_metadata) == 1)
    q_dist_sigma = variable_metadata[0].posterior.scale

    with self.test_session() as sess:
      sigma_res = sess.run(q_dist_sigma)
    self.assertAllClose(sigma_res,
                        np.zeros(shape=[2]) + softplus_of_three,
                        atol=1e-5)
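The expected value softplus_of_three above relies on a module-level softplus helper that the snippet does not show. A minimal NumPy version, assuming it simply mirrors tf.nn.softplus (which the diagonal-gaussian posterior builder applies to its raw scale variable):

import numpy as np

def softplus(x):
  # log(1 + exp(x)), matching tf.nn.softplus.
  return np.log1p(np.exp(x))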
Example #9
  def testLastSampleMode(self):
    """Tests that the 'last sample' estimator mode uses the last sample."""

    class CustomNormal(tfp.distributions.Normal):
      """A custom normal distribution which implements `self.last_sample()`."""

      def __init__(self, *args, **kwargs):
        super(CustomNormal, self).__init__(*args, **kwargs)
        self._noise = tf.get_variable(
            name=self.loc.name.replace(":", "_") + "_noise",
            shape=self.loc.shape,
            dtype=self.loc.dtype,
            initializer=tf.random_normal_initializer(0.0, 1.0),
            trainable=False)

      def sample(self):
        noise = self._noise.assign(tf.random_normal(self.loc.shape))
        return self.last_sample(noise)

      def last_sample(self, noise=None):
        if noise is None:
          noise = self._noise
        return noise * self.scale + self.loc

    sampling_mode_tensor = tf.get_variable(
        name="sampling_mode",
        dtype=tf.string,
        shape=(),
        trainable=False,
        initializer=tf.constant_initializer(bbb.EstimatorModes.sample))
    enter_last_sample_mode = tf.assign(
        sampling_mode_tensor, tf.constant(bbb.EstimatorModes.last_sample))
    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=test_diag_gaussian_builder_builder(
            dist_cls=CustomNormal),
        prior_builder=bbb.adaptive_gaussian_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=sampling_mode_tensor)
    with tf.variable_scope("model_scope", custom_getter=bbb_getter):
      model = snt.Linear(5)

    data = tf.placeholder(shape=(2, 4), dtype=tf.float32)
    outputs = model(data)

    # We expect there to be 8 trainable variables: the model (a Linear) has
    # two variables (weight and bias); the posterior has two variables (mu and
    # sigma) for each of them, and so does the prior (since it's adaptive).
    # That gives 2 * 2 * 2 = 8.
    self.assertEqual(len(tf.trainable_variables()), 2*2*2)

    init_op = tf.global_variables_initializer()
    x_feed = np.random.normal(size=(2, 4))
    with self.test_session() as sess:
      sess.run(init_op)
      output_res_one = sess.run(outputs, feed_dict={data: x_feed})
      output_res_two = sess.run(outputs, feed_dict={data: x_feed})
      sess.run(enter_last_sample_mode)
      output_res_three = sess.run(outputs, feed_dict={data: x_feed})
      output_res_four = sess.run(outputs, feed_dict={data: x_feed})

    # One and two should be different samples.
    self.assertTrue((output_res_one != output_res_two).all())
    # Two through four should be the same.
    self.assertAllClose(output_res_two, output_res_three)
    self.assertAllClose(output_res_three, output_res_four)
    self.assertAllClose(output_res_two, output_res_four)
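Several of these tests also pass a test_diag_gaussian_builder_builder helper to the getter, which is not shown in the snippets. Below is a plausible sketch consistent with how it is called (constant initial loc and raw scale, a softplus applied to the raw scale, and an overridable distribution class); the signature, defaults, and kwargs handling are assumptions, not the actual helper:

def test_diag_gaussian_builder_builder(
    init_loc=0.0, init_scale=1.0, dist_cls=tfp.distributions.Normal):
  """Returns a posterior_builder with constant initializers (sketch only)."""

  def builder(getter, name, *args, **kwargs):
    del args
    parameter_shapes = dist_cls.param_static_shapes(kwargs.pop("shape"))
    kwargs.pop("initializer", None)
    loc = getter(
        "{}/posterior_loc".format(name),
        shape=parameter_shapes["loc"],
        initializer=tf.constant_initializer(init_loc),
        **kwargs)
    scale = getter(
        "{}/posterior_scale".format(name),
        shape=parameter_shapes["scale"],
        initializer=tf.constant_initializer(init_scale),
        **kwargs)
    # The raw scale is unconstrained and pushed through a softplus, so e.g.
    # init_scale=3.0 gives a posterior stddev of softplus(3.0).
    return dist_cls(loc=loc, scale=tf.nn.softplus(scale))

  return builder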
Example #10
  def test_variable_sharing(self):
    _, x_size = input_shape = [5, 5]

    sample_mode = tf.constant(bbb.EstimatorModes.sample)
    mean_mode = tf.constant(bbb.EstimatorModes.mean)
    sampling_mode = tf.get_variable(
        "bbb_sampling_mode",
        initializer=tf.constant_initializer(bbb.EstimatorModes.sample),
        dtype=tf.string,
        shape=(),
        trainable=False)
    set_to_sample_mode = tf.assign(sampling_mode, sample_mode)
    set_to_mean_mode = tf.assign(sampling_mode, mean_mode)

    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=bbb.diagonal_gaussian_posterior_builder,
        prior_builder=bbb.fixed_gaussian_prior_builder,
        kl_builder=bbb.stochastic_kl_builder,
        sampling_mode_tensor=sampling_mode)

    tf.get_variable_scope().set_custom_getter(bbb_getter)
    mlp = snt.nets.MLP(output_sizes=[32, x_size])
    x_train = tf.placeholder(dtype=tf.float32, shape=input_shape)
    x_test = tf.placeholder(dtype=tf.float32, shape=input_shape)

    # Dummy targets.
    target_train = x_train + 3.0
    target_test = x_test + 3.0

    y_train = mlp(x_train)

    # In mean mode, y_test should also be deterministic for a fixed x.
    y_test = mlp(x_test)

    # Expect two posterior parameters for each of w and b in each of the two
    # MLP layers. That's 2 * 2 * 2 = 8. The test graph reuses these same
    # variables, so it adds none.
    expected_number_of_variables = 8
    actual_number_of_variables = len(tf.trainable_variables())
    self.assertTrue(expected_number_of_variables == actual_number_of_variables)

    loss_train = tf.reduce_sum(tf.square(y_train - target_train),
                               reduction_indices=[1])
    loss_train = tf.reduce_mean(loss_train, reduction_indices=[0])
    loss_test = tf.reduce_sum(tf.square(y_test - target_test),
                              reduction_indices=[1])
    loss_test = tf.reduce_mean(loss_test)

    kl_cost = bbb.get_total_kl_cost() * 0.000001
    total_train_loss = loss_train + kl_cost
    optimizer = tf.train.GradientDescentOptimizer(0.001)
    train_step = optimizer.minimize(total_train_loss)

    x_feed = np.random.normal(size=input_shape)
    fd = {
        x_train: x_feed,
        x_test: x_feed
    }

    init_op = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init_op)
      sess.run(set_to_mean_mode)
      y_test_res_one = sess.run(y_test, feed_dict=fd)
      y_test_res_two = sess.run(y_test, feed_dict=fd)
      sess.run(set_to_sample_mode)
    self.assertAllClose(y_test_res_one, y_test_res_two)

    n_train = 10
    check_freq = 2
    with self.test_session() as sess:
      for i in xrange(n_train):
        if i % check_freq == 0:
          sess.run(set_to_mean_mode)
          to_run = [y_train, y_test, loss_train, loss_test, kl_cost]
        else:
          to_run = [y_train, y_test, loss_train, loss_test, kl_cost, train_step]
        res = sess.run(to_run, feed_dict=fd)
        loss_train_res, loss_test_res = res[2:4]

        if i % check_freq == 0:
          self.assertAllClose(loss_train_res, loss_test_res)
          sess.run(set_to_sample_mode)
Example #11
  def test_sample_mode_is_stochastic_and_can_be_switched(self):
    use_mean = tf.constant(bbb.EstimatorModes.mean)
    use_sample = tf.constant(bbb.EstimatorModes.sample)
    sampling_mode = tf.get_variable(
        "bbb_sampling_mode",
        initializer=tf.constant_initializer(bbb.EstimatorModes.sample),
        dtype=tf.string,
        shape=(),
        trainable=False)
    set_to_mean_mode = tf.assign(sampling_mode, use_mean)
    set_to_sample_mode = tf.assign(sampling_mode, use_sample)

    softplus_of_twenty = softplus(20.0)
    bbb_getter = bbb.bayes_by_backprop_getter(
        posterior_builder=test_diag_gaussian_builder_builder(10.9, 20.0),
        prior_builder=bbb.fixed_gaussian_prior_builder,
        sampling_mode_tensor=sampling_mode)

    with tf.variable_scope("my_scope", custom_getter=bbb_getter):
      my_variable = tf.get_variable("v", shape=[10, 3], dtype=tf.float32)

    # Check that the distribution has the right parameters.
    variable_metadata = bbb.get_variable_metadata()
    self.assertTrue(len(variable_metadata) == 1)
    q_dist_mean = variable_metadata[0].posterior.loc
    q_dist_sigma = variable_metadata[0].posterior.scale

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    with self.test_session() as sess:
      sess.run(init_op)
      mean_res, sigma_res = sess.run([q_dist_mean, q_dist_sigma])
      variable_value_one = sess.run(my_variable)
      variable_value_two = sess.run(my_variable)
    self.assertAllClose(mean_res, np.zeros(shape=[10, 3])+10.9)
    self.assertAllClose(sigma_res, np.zeros(shape=[10, 3]) + softplus_of_twenty)

    actual_distance = np.sqrt(
        np.sum(np.square(variable_value_one - variable_value_two)))
    expected_distance_minimum = 5
    self.assertGreater(actual_distance, expected_distance_minimum)

    # Now the value should be deterministic again.
    with self.test_session() as sess:
      sess.run(set_to_mean_mode)
      variable_value_three = sess.run(my_variable)
      variable_value_four = sess.run(my_variable)
      variable_value_five = sess.run(my_variable)
    self.assertAllClose(variable_value_three,
                        np.zeros(shape=[10, 3]) + 10.9,
                        atol=1e-5)
    self.assertAllClose(variable_value_four,
                        np.zeros(shape=[10, 3]) + 10.9,
                        atol=1e-5)
    self.assertAllClose(variable_value_five,
                        np.zeros(shape=[10, 3]) + 10.9,
                        atol=1e-5)

    # Now it should be stochastic again.
    with self.test_session() as sess:
      sess.run(set_to_sample_mode)
      variable_value_six = sess.run(my_variable)
      variable_value_seven = sess.run(my_variable)
    actual_new_distance = np.sqrt(
        np.sum(np.square(variable_value_six - variable_value_seven)))
    self.assertGreater(actual_new_distance, expected_distance_minimum)
Example #12
                       initializer=tf.random_uniform(
                           minval=np.log(np.exp(prior_stddev / 2.0) - 1.0),
                           maxval=np.log(np.exp(prior_stddev / 1.0) - 1.0),
                           dtype=tf.float32,
                           shape=parameter_shapes["scale"]))
    return tf.contrib.distributions.Normal(
        loc=loc_var,
        scale=tf.nn.softplus(scale_var) + 1e-5,
        name="{}/posterior_dist".format(name))


estimator_mode = tf.constant(bbb.EstimatorModes.sample)

lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
    posterior_builder=lstm_posterior_builder,
    prior_builder=custom_scale_mixture_prior_builder,
    kl_builder=bbb.stochastic_kl_builder,
    sampling_mode_tensor=estimator_mode)

non_lstm_bbb_custom_getter = bbb.bayes_by_backprop_getter(
    posterior_builder=non_lstm_posterior_builder,
    prior_builder=custom_scale_mixture_prior_builder,
    kl_builder=bbb.stochastic_kl_builder,
    sampling_mode_tensor=estimator_mode)


class Linear(base.AbstractModule):
    """Linear module, optionally including bias."""
    def __init__(self,
                 output_size,
                 use_bias=True,